add glow-tts model and layers

erogol 2020-08-10 12:57:15 +02:00
parent 43771a3a5c
commit 383c5f7185
15 changed files with 26738 additions and 0 deletions


@@ -0,0 +1,651 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import glob
import os
import sys
import time
import traceback
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from mozilla_voice_tts.tts.datasets.preprocess import load_meta_data
from mozilla_voice_tts.tts.datasets.TTSDataset import MyDataset
from mozilla_voice_tts.tts.layers.losses import GlowTTSLoss
from mozilla_voice_tts.utils.console_logger import ConsoleLogger
from mozilla_voice_tts.tts.utils.distribute import (DistributedSampler,
init_distributed, reduce_tensor)
from mozilla_voice_tts.tts.utils.generic_utils import check_config, setup_model
from mozilla_voice_tts.tts.utils.io import save_best_model, save_checkpoint
from mozilla_voice_tts.tts.utils.measures import alignment_diagonal_score
from mozilla_voice_tts.tts.utils.speakers import (get_speakers, load_speaker_mapping,
save_speaker_mapping)
from mozilla_voice_tts.tts.utils.synthesis import synthesis
from mozilla_voice_tts.tts.utils.text.symbols import make_symbols, phonemes, symbols
from mozilla_voice_tts.tts.utils.visual import plot_alignment, plot_spectrogram
from mozilla_voice_tts.utils.audio import AudioProcessor
from mozilla_voice_tts.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from mozilla_voice_tts.utils.io import copy_config_file, load_config
from mozilla_voice_tts.utils.radam import RAdam
from mozilla_voice_tts.utils.tensorboard_logger import TensorboardLogger
from mozilla_voice_tts.utils.training import (NoamLR, adam_weight_decay, check_update,
gradual_training_scheduler, set_weight_decay,
setup_torch_training_env)
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False):
if is_val and not c.run_eval:
loader = None
else:
dataset = MyDataset(
r,
c.text_cleaner,
compute_linear_spec=True if c.model.lower() == 'tacotron' else False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=c.eval_batch_size if is_val else c.batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
return loader
def format_data(data):
if c.use_speaker_embedding:
speaker_mapping = load_speaker_mapping(OUT_PATH)
# setup input data
text_input = data[0]
text_lengths = data[1]
speaker_names = data[2]
mel_input = data[4].permute(0, 2, 1) # B x D x T
mel_lengths = data[5]
attn_mask = data[8]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
if c.use_speaker_embedding:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
else:
speaker_ids = None
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda(non_blocking=True)
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
if attn_mask is not None:
attn_mask = attn_mask.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
avg_text_length, avg_spec_length, attn_mask
def data_depended_init(model, ap):
"""Data depended initialization for normalization layers."""
if hasattr(model, 'module'):
for f in model.module.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(True)
else:
for f in model.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(True)
data_loader = setup_loader(ap, 1, is_val=False)
model.train()
print(" > Data depended initialization ... ")
with torch.no_grad():
for num_iter, data in enumerate(data_loader):
# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
avg_text_length, avg_spec_length, attn_mask = format_data(data)
# forward pass model
_ = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
break
if hasattr(model, 'module'):
for f in model.module.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(False)
else:
for f in model.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(False)
return model
def train(model, criterion, optimizer, scheduler,
ap, global_step, epoch, amp):
data_loader = setup_loader(ap, 1, is_val=False,
verbose=(epoch == 0))
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
c_logger.print_train_start()
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
avg_text_length, avg_spec_length, attn_mask = format_data(data)
loader_time = time.time() - end_time
global_step += 1
# setup lr
if c.noam_schedule:
scheduler.step()
optimizer.zero_grad()
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
# backward pass
if amp is not None:
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss_dict['loss'].backward()
if amp:
amp_opt_params = amp.master_params(optimizer)
else:
amp_opt_params = None
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
optimizer.step()
# current_lr
current_lr = optimizer.param_groups[0]['lr']
# compute alignment error (the lower, the better)
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
step_time = time.time() - start_time
epoch_time += step_time
# aggregate losses from processes
if num_gpus > 1:
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
keep_avg.update_values(update_train_values)
# print training progress
if global_step % c.print_step == 0:
log_dict = {
"avg_spec_length": [avg_spec_length, 1], # value, precision
"avg_text_length": [avg_text_length, 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % c.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
model_loss=loss_dict['loss'],
amp_state_dict=amp.state_dict() if amp else None)
# Diagnostic visualizations
# direct pass on model for spec predictions
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
const_spec = spec_pred[0].data.cpu().numpy()
gt_spec = gt_spec[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(const_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img),
}
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
train_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
# Plot Epoch Stats
if args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(keep_avg.avg_values)
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
if c.tb_model_param_stats:
tb_logger.tb_model_weights(model, global_step)
return keep_avg.avg_values, global_step
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
data_loader = setup_loader(ap, 1, is_val=True)
model.eval()
epoch_time = 0
keep_avg = KeepAverage()
c_logger.print_eval_start()
if data_loader is not None:
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
avg_text_length, avg_spec_length, attn_mask = format_data(data)
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
# step time
step_time = time.time() - start_time
epoch_time += step_time
# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
# aggregate losses from processes
if num_gpus > 1:
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
keep_avg.update_values(update_train_values)
if c.print_eval:
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Diagnostic visualizations
# direct pass on model for spec predictions
if hasattr(model, 'module'):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1])
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
const_spec = spec_pred[0].data.cpu().numpy()
gt_spec = gt_spec[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
eval_figures = {
"prediction": plot_spectrogram(const_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img)
}
# Sample audio
eval_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
# Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
tb_logger.tb_eval_figures(global_step, eval_figures)
if args.rank == 0 and epoch > c.test_delay_epochs:
if c.test_sentences_file is None:
test_sentences = [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
]
else:
with open(c.test_sentences_file, "r") as f:
test_sentences = [s.strip() for s in f.readlines()]
# test sentences
test_audios = {}
test_figures = {}
print(" | > Synthesizing test sentences")
speaker_id = 0 if c.use_speaker_embedding else None
style_wav = c.get("style_wav_for_test")
for idx, test_sentence in enumerate(test_sentences):
try:
wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
model,
test_sentence,
c,
use_cuda,
ap,
speaker_id=speaker_id,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except:
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=global-variable-undefined
global meta_data_train, meta_data_eval, symbols, phonemes
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
symbols, phonemes = make_symbols(**c.characters)
# DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
# set the portion of the data used for training
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now, you cannot " \
"introduce new speakers to " \
"a previously trained model."
else:
speaker_mapping = {name: i for i, name in enumerate(speakers)}
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
", ".join(speakers)))
else:
num_speakers = 0
# setup model
model = setup_model(num_chars, num_speakers, c)
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = GlowTTSLoss()
if c.apex_amp_level:
# pylint: disable=import-outside-toplevel
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
amp = None
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
optimizer.load_state_dict(checkpoint['optimizer'])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except:
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model.load_state_dict(model_dict)
del model_dict
if amp and 'amp' in checkpoint:
amp.load_state_dict(checkpoint['amp'])
for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
else:
args.restore_step = 0
if use_cuda:
model.cuda()
criterion.cuda()
# DISTRIBUTED
if num_gpus > 1:
model = DDP(model)
if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
else:
scheduler = None
num_params = count_parameters(model)
print("\n > Model has {} parameters".format(num_params), flush=True)
if 'best_loss' not in locals():
best_loss = float('inf')
global_step = args.restore_step
model = data_depended_init(model, ap)
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
epoch, amp)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--continue_path',
type=str,
help='Training output folder to continue a previous training run. If given, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv
)
parser.add_argument('--debug',
type=bool,
default=False,
help='Do not verify commit integrity to run training.')
# DISTRIBUTED
parser.add_argument(
'--rank',
type=int,
default=0,
help='DISTRIBUTED: process rank for distributed training.')
parser.add_argument('--group_id',
type=str,
default="",
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(args.continue_path + "/*.pth.tar")  # collect all saved checkpoints
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level:
print(" > apex AMP level: ", c.apex_amp_level)
OUT_PATH = args.continue_path
if args.continue_path == '':
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
c_logger = ConsoleLogger()
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_config_file(args.config_path,
os.path.join(OUT_PATH, 'config.json'), new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
# write model desc to tensorboard
tb_logger.tb_add_text('model-description', c['run_description'], 0)
try:
main(args)
except KeyboardInterrupt:
remove_experiment_folder(OUT_PATH)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH)
traceback.print_exc()
sys.exit(1)


@@ -0,0 +1,202 @@
import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
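# Illustration: convert_pad_shape() turns a per-dimension [[before, after], ...]
# spec (first dimension first) into the flat, last-dimension-first list that
# F.pad expects, e.g.
#     convert_pad_shape([[0, 0], [0, 0], [1, 0]]) -> [1, 0, 0, 0, 0, 0]
# so a (B, C, T) tensor gets one step of left-padding on the time axis.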
class MultiHeadAttention(nn.Module):
def __init__(self,
channels,
out_channels,
n_heads,
window_size=None,
heads_share=True,
dropout_p=0.,
input_length=None,
proximal_bias=False,
proximal_init=False):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.window_size = window_size
self.heads_share = heads_share
self.input_length = input_length
self.proximal_bias = proximal_bias
self.dropout_p = dropout_p
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(dropout_p)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
if proximal_init:
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels,
t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels,
t_s).transpose(2, 3)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
self.k_channels)
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(
rel_logits)
scores_local = rel_logits / math.sqrt(self.k_channels)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.input_length is not None:
block_mask = torch.ones_like(scores).triu(
-self.input_length).tril(self.input_length)
scores = scores * block_mask + -1e4 * (1 - block_mask)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(
p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(
b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
convert_pad_shape([[0, 0], [pad_length, pad_length],
[0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:
slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, convert_pad_shape([[0, 0], [0, 0], [0,
length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1,
2 * length - 1])[:, :, :length, length - 1:]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# pad along column
x = F.pad(
x,
convert_pad_shape([[0, 0], [0, 0], [0, 0], [0,
length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(
x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(
torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


@@ -0,0 +1,63 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def maximum_path(value, mask, max_neg_val=-np.inf):
""" Numpy-friendly version. It's about 4 times faster than torch version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(np.bool)
b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64)
v = np.zeros((b, t_x), dtype=np.float32)
x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
for j in range(t_y):
v0 = np.pad(v, [[0, 0], [1, 0]],
mode="constant",
constant_values=max_neg_val)[:, :-1]
v1 = v
max_mask = (v1 >= v0)
v_max = np.where(max_mask, v1, v0)
direction[:, :, j] = max_mask
index_mask = (x_range <= j)
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
direction = np.where(mask, direction, 1)
path = np.zeros(value.shape, dtype=np.float32)
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
index_range = np.arange(b)
for j in reversed(range(t_y)):
path[index_range, index, j] = 1
index = index + direction[index_range, index, j] - 1
path = path * mask.astype(np.float32)
path = torch.from_numpy(path).to(device=device, dtype=dtype)
return path
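A minimal usage sketch of maximum_path (the tensors below are random placeholders and reuse the imports above):

value = torch.randn(1, 5, 20)        # [b, t_x, t_y] alignment log-likelihoods
mask = torch.ones(1, 5, 20)          # full valid region
path = maximum_path(value, mask)     # hard monotonic alignment, [b, t_x, t_y]
# exactly one text position is selected for every decoder frame
assert path.sum() == 20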


@@ -0,0 +1,110 @@
import torch
from torch import nn
from torch.nn import functional as F
from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask
from mozilla_voice_tts.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock, ActNorm
def squeeze(x, x_mask=None, num_sqz=2):
b, c, t = x.size()
t = (t // num_sqz) * num_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1,
2).contiguous().view(b, c * num_sqz, t // num_sqz)
if x_mask is not None:
x_mask = x_mask[:, :, num_sqz - 1::num_sqz]
else:
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device,
dtype=x.dtype)
return x_sqz * x_mask, x_mask
def unsqueeze(x, x_mask=None, num_sqz=2):
b, c, t = x.size()
x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3,
1).contiguous().view(b, c // num_sqz,
t * num_sqz)
if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1,
num_sqz).view(b, 1, t * num_sqz)
else:
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device,
dtype=x.dtype)
return x_unsqz * x_mask, x_mask
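# Illustration: with num_sqz=2, squeeze() turns a (B, 80, T) mel tensor into
# (B, 160, T // 2) by stacking pairs of frames along the channel axis (trailing
# frames that do not fill a full group are dropped), and unsqueeze() restores
# the original (B, 80, T) layout.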
class Decoder(nn.Module):
"""Stack of Glow Modules"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p=0.,
num_splits=4,
num_sqz=2,
sigmoid_scale=False,
c_in_channels=0,
feat_channels=None):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_flow_blocks = num_flow_blocks
self.num_coupling_layers = num_coupling_layers
self.dropout_p = dropout_p
self.num_splits = num_splits
self.num_sqz = num_sqz
self.sigmoid_scale = sigmoid_scale
self.c_in_channels = c_in_channels
self.flows = nn.ModuleList()
for _ in range(num_flow_blocks):
self.flows.append(ActNorm(channels=in_channels * num_sqz))
self.flows.append(
InvConvNear(channels=in_channels * num_sqz,
num_splits=num_splits))
self.flows.append(
CouplingBlock(in_channels * num_sqz,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_coupling_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
sigmoid_scale=sigmoid_scale))
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
flows = self.flows
logdet_tot = 0
else:
flows = reversed(self.flows)
logdet_tot = None
if self.num_sqz > 1:
x, x_mask = squeeze(x, x_mask, self.num_sqz)
for f in flows:
if not reverse:
x, logdet = f(x, x_mask, g=g, reverse=reverse)
logdet_tot += logdet
else:
x, logdet = f(x, x_mask, g=g, reverse=reverse)
if self.num_sqz > 1:
x, x_mask = unsqueeze(x, x_mask, self.num_sqz)
return x, logdet_tot
def store_inverse(self):
for f in self.flows:
f.store_inverse()
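A small shape-level sketch of the flow decoder in isolation (the sizes are illustrative placeholders, not project defaults):

dec = Decoder(in_channels=80, hidden_channels=192, kernel_size=5,
dilation_rate=1, num_flow_blocks=12, num_coupling_layers=4)
y = torch.randn(2, 80, 100)                # mel frames [B, C, T]
y_mask = torch.ones(2, 1, 100)
z, logdet = dec(y, y_mask, reverse=False)  # mel -> latent, with total log|det J|
y_hat, _ = dec(z, y_mask, reverse=True)    # latent -> mel, recovers y up to float error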


@@ -0,0 +1,42 @@
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import functional as F
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, dropout_p):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.drop = nn.Dropout(dropout_p)
self.conv_1 = nn.Conv1d(in_channels,
filter_channels,
kernel_size,
padding=kernel_size // 2)
self.norm_1 = nn.GroupNorm(1, filter_channels)
self.conv_2 = nn.Conv1d(filter_channels,
filter_channels,
kernel_size,
padding=kernel_size // 2)
self.norm_2 = nn.GroupNorm(1, filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
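A shape-level usage sketch (sizes are illustrative placeholders):

dp = DurationPredictor(in_channels=192, filter_channels=256, kernel_size=3, dropout_p=0.1)
h = torch.randn(2, 192, 50)     # encoder hidden states [B, C, T]
h_mask = torch.ones(2, 1, 50)
log_dur = dp(h, h_mask)         # predicted log-durations, one per input token, [B, 1, T]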


@@ -0,0 +1,138 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask
from mozilla_voice_tts.tts.layers.glow_tts.glow import ConvLayerNorm
from mozilla_voice_tts.tts.layers.glow_tts.duration_predictor import DurationPredictor
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
self.register_buffer('pe', self._get_pe_matrix(d_model, max_len))
def forward(self, x):
return x + self.pe[:x.size(0)].unsqueeze(1)
def _get_pe_matrix(self, d_model, max_len):
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.pow(10000,
torch.arange(0, d_model, 2).float() / d_model)
pe[:, 0::2] = torch.sin(position / div_term)
pe[:, 1::2] = torch.cos(position / div_term)
return pe
class Encoder(nn.Module):
"""Glow-TTS encoder module. We use Pytorch TransformerEncoder instead
of the one with relative position embedding. We use positional encoding
for capturing positiong information.
Args:
num_chars (int): number of characters.
out_channels (int): number of output channels.
hidden_channels (int): encoder's embedding size.
filter_channels (int): transformer's feed-forward channels.
num_heads (int): number of attention heads in transformer.
num_layers (int): number of stacked transformer encoder layers.
kernel_size (int): kernel size for conv layers and duration predictor.
dropout_p (float): dropout rate for any dropout layer.
mean_only (bool): if True, output only mean values and use constant std.
use_prenet (bool): if True, use pre-convolutional layers before transformer layers.
c_in_channels (int): number of channels in conditional input.
Shapes:
- input: (B, T)
"""
def __init__(self,
num_chars,
out_channels,
hidden_channels,
filter_channels,
filter_channels_dp,
num_heads,
num_layers,
kernel_size,
dropout_p,
mean_only=False,
use_prenet=False,
c_in_channels=0):
super().__init__()
self.num_chars = num_chars
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.filter_channels_dp = filter_channels_dp
self.num_heads = num_heads
self.num_layers = num_layers
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.mean_only = mean_only
self.use_prenet = use_prenet
self.c_in_channels = c_in_channels
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.register_buffer('pe', PositionalEncoding(hidden_channels).pe)
if use_prenet:
self.pre = ConvLayerNorm(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
encoder = nn.TransformerEncoderLayer(hidden_channels,
num_heads,
dim_feedforward=filter_channels,
dropout=dropout_p)
self.encoder = nn.TransformerEncoder(encoder, num_layers)
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
if not mean_only:
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
self.duration_predictor = DurationPredictor(
hidden_channels + c_in_channels, filter_channels_dp, kernel_size,
dropout_p)
def forward(self, x, x_lengths, g=None):
# pass embedding layer
# [B, T, D]
x = self.emb(x) * math.sqrt(self.hidden_channels)
x += self.pe[:x.shape[1]].unsqueeze(0)
# [B, D, T]
x = torch.transpose(x, 1, -1)
# compute input sequence mask
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
1).to(x.dtype)
# pass pre-conv layers
if self.use_prenet:
x = self.pre(x, x_mask)
# pass encoder
x = self.encoder(x.permute(2, 0, 1),
src_key_padding_mask=~(x_mask.squeeze(1).to(bool)))
x = x.permute(1, 2, 0)
# set duration predictor input
if g is not None:
g_exp = g.expand(-1, -1, x.size(-1))
x_dp = torch.cat([torch.detach(x), g_exp], 1)
else:
x_dp = torch.detach(x)
# pass final projection layer
x_m = self.proj_m(x) * x_mask
if not self.mean_only:
x_logs = self.proj_s(x) * x_mask
else:
x_logs = torch.zeros_like(x_m)
# pass duration predictor
logw = self.duration_predictor(x_dp, x_mask)
return x_m, x_logs, logw, x_mask


@@ -0,0 +1,358 @@
import torch
from torch import nn
from torch.nn import functional as F
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-4):
"""Layer norm for the 2nd dimension of the input.
Args:
channels (int): number of channels (2nd dimension) of the input.
eps (float): small constant added to the variance to avoid division by zero.
Shapes:
- input: (B, C, T)
- output: (B, C, T)
"""
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1)
self.beta = nn.Parameter(torch.zeros(1, channels, 1))
def forward(self, x):
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean)**2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
x = x * self.gamma + self.beta
return x
class ConvLayerNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
num_layers, dropout_p):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.num_layers = num_layers
self.dropout_p = dropout_p
assert num_layers > 1, " [!] number of layers should be > 1."
assert kernel_size % 2 == 1, " [!] kernel size should be an odd number."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(in_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
for _ in range(num_layers - 1):
self.conv_layers.append(
nn.Conv1d(hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.num_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x * x_mask)
x = F.dropout(F.relu(x), self.dropout_p, training=self.training)
x = x_org + self.proj(x)
return x * x_mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
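# Note: this is the WaveNet-style gated activation. The first n_channels of
# (input_a + input_b) pass through tanh, the remaining channels through
# sigmoid, and the two halves are multiplied element-wise.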
class WN(torch.nn.Module):
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0):
super(WN, self).__init__()
assert kernel_size % 2 == 1
assert hidden_channels % 2 == 0
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.c_in_channels = c_in_channels
self.dropout_p = dropout_p
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.dropout = nn.Dropout(dropout_p)
if c_in_channels != 0:
cond_layer = torch.nn.Conv1d(c_in_channels,
2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
name='weight')
for i in range(num_layers):
dilation = dilation_rate**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)
if i < num_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels,
res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.num_layers):
x_in = self.in_layers[i](x)
x_in = self.dropout(x_in)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,
cond_offset:cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.num_layers - 1:
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
output = output + res_skip_acts[:, self.hidden_channels:, :]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.c_in_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ActNorm(nn.Module):
"""Activation Normalization bijector as an alternative to Batch Norm. It computes
mean and std from a sample data in advance and it uses these values
for normalization at training.
Args:
channels (int): input channels.
ddi (bool): enable data-dependent initialization.
Shapes:
- inputs: (B, C, T)
- outputs: (B, C, T)
"""
def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument
super().__init__()
self.channels = channels
self.initialized = not ddi
self.logs = nn.Parameter(torch.zeros(1, channels, 1))
self.bias = nn.Parameter(torch.zeros(1, channels, 1))
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
self.initialized = True
if reverse:
z = (x - self.bias) * torch.exp(-self.logs) * x_mask
logdet = None
else:
z = (self.bias + torch.exp(self.logs) * x) * x_mask
logdet = torch.sum(self.logs) * x_len # [b]
return z, logdet
def store_inverse(self):
pass
def set_ddi(self, ddi):
self.initialized = not ddi
def initialize(self, x, x_mask):
with torch.no_grad():
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(
dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
self.logs.data.copy_(logs_init)
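# Usage note (illustrative): with ddi=True, the first forward pass calls
# initialize(), choosing bias/logs so that this layer's output has roughly
# zero mean and unit variance on that batch; data_depended_init() in the
# training script above toggles this via set_ddi() before the first batch.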
class InvConvNear(nn.Module):
def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument
super().__init__()
assert num_splits % 2 == 0
self.channels = channels
self.num_splits = num_splits
self.no_jacobian = no_jacobian
self.weight_inv = None
w_init = torch.qr(
torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
"""Split the input into groups of size self.num_splits and
perform 1x1 convolution separately. Cast 1x1 conv operation
to 2d by reshaping the input for efficienty.
"""
b, c, t = x.size()
assert c % self.num_splits == 0
if x_mask is None:
x_mask = 1
x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t
else:
x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits,
c // self.num_splits, t)
if reverse:
if self.weight_inv is not None:
weight = self.weight_inv
else:
weight = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
logdet = None
else:
weight = self.weight
if self.no_jacobian:
logdet = 0
else:
logdet = torch.logdet(
self.weight) * (c / self.num_splits) * x_len # [b]
weight = weight.view(self.num_splits, self.num_splits, 1, 1)
z = F.conv2d(x, weight)
z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t)
z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
return z, logdet
def store_inverse(self):
self.weight_inv = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
class CouplingBlock(nn.Module):
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
sigmoid_scale=False):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.c_in_channels = c_in_channels
self.dropout_p = dropout_p
self.sigmoid_scale = sigmoid_scale
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
start = torch.nn.utils.weight_norm(start)
self.start = start
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
end.weight.data.zero_()
end.bias.data.zero_()
self.end = end
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate,
num_layers, c_in_channels, dropout_p)
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
if x_mask is None:
x_mask = 1
x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
x = self.start(x_0) * x_mask
x = self.wn(x, x_mask, g)
out = self.end(x)
z_0 = x_0
m = out[:, :self.in_channels // 2, :]
logs = out[:, self.in_channels // 2:, :]
if self.sigmoid_scale:
logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
if reverse:
z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
logdet = None
else:
z_1 = (m + torch.exp(logs) * x_1) * x_mask
logdet = torch.sum(logs * x_mask, [1, 2])
z = torch.cat([z_0, z_1], 1)
return z, logdet
def store_inverse(self):
self.wn.remove_weight_norm()


@@ -0,0 +1,125 @@
import torch
from torch import nn
from torch.nn import functional as F
from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask
from mozilla_voice_tts.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock, ActNorm
def squeeze(x, x_mask=None, num_sqz=2):
b, c, t = x.size()
t = (t // num_sqz) * num_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1,
2).contiguous().view(b, c * num_sqz, t // num_sqz)
if x_mask is not None:
x_mask = x_mask[:, :, num_sqz - 1::num_sqz]
else:
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device,
dtype=x.dtype)
return x_sqz * x_mask, x_mask
def unsqueeze(x, x_mask=None, num_sqz=2):
b, c, t = x.size()
x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3,
1).contiguous().view(b, c // num_sqz,
t * num_sqz)
if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1,
num_sqz).view(b, 1, t * num_sqz)
else:
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device,
dtype=x.dtype)
return x_unsqz * x_mask, x_mask
class Decoder(nn.Module):
"""Stack of Glow Modules"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_blocks,
num_coupling_layers,
dropout_p=0.,
num_splits=4,
num_sqz=2,
sigmoid_scale=False,
c_in_channels=0):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_blocks = num_blocks
self.num_coupling_layers = num_coupling_layers
self.dropout_p = dropout_p
self.num_splits = num_splits
self.num_sqz = num_sqz
self.sigmoid_scale = sigmoid_scale
self.c_in_channels = c_in_channels
self.norm_layers = nn.ModuleList()
self.invconv_layers = nn.ModuleList()
self.coupling_layers = nn.ModuleList()
for _ in range(num_blocks):
self.norm_layers.append(ActNorm(channels=in_channels * num_sqz))
self.invconv_layers.append(
InvConvNear(channels=in_channels * num_sqz,
num_splits=num_splits))
self.coupling_layers.append(
CouplingBlock(in_channels * num_sqz,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_coupling_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
sigmoid_scale=sigmoid_scale))
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
# norm_layers = self.norm_layers
invconv_layers = self.invconv_layers
coupling_layers = self.coupling_layers
logdet_tot = 0
else:
# norm_layers = reversed(self.norm_layers)
invconv_layers = self.invconv_layers[::-1]
coupling_layers = self.coupling_layers[::-1]
logdet_tot = None
if self.num_sqz > 1:
x, x_mask = squeeze(x, x_mask, self.num_sqz)
for idx in range(len(self.invconv_layers)):
if not reverse:
# x, log_det = norm_layers[idx](x, reverse)
# logdet_tot += log_det
x, log_det = invconv_layers[idx](x, x_mask, reverse)
logdet_tot += log_det
x, log_det = coupling_layers[idx](x, x_mask, g=g, reverse=reverse)
logdet_tot += log_det
else:
x, log_det = coupling_layers[idx](x, x_mask, g=g, reverse=reverse)
x, log_det = invconv_layers[idx](x, x_mask, reverse)
# x, logdet = norm_layers[idx](x, reverse)
if self.num_sqz > 1:
x, x_mask = unsqueeze(x, x_mask, self.num_sqz)
return x, logdet_tot
def store_inverse(self):
for f in self.invconv_layers:
f.store_inverse()
for f in self.coupling_layers:
f.store_inverse()


@@ -0,0 +1,49 @@
import numpy as np
import torch
from torch.nn import functional as F
from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask
from mozilla_voice_tts.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def generate_path(duration, mask):
"""
duration: [b, t_x]
mask: [b, t_x, t_y]
"""
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]
]))[:, :-1]
path = path * mask
return path
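# Illustration: for a single sample with durations [2, 1, 3] and a full mask,
# generate_path() returns a hard monotonic alignment where frames 0-1 attend
# to token 0, frame 2 to token 1, and frames 3-5 to token 2.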
def maximum_path(value, mask):
""" Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.data.cpu().numpy().astype(np.float32)
path = np.zeros_like(value).astype(np.int32)
mask = mask.data.cpu().numpy()
t_x_max = mask.sum(1)[:, 0].astype(np.int32)
t_y_max = mask.sum(2)[:, 0].astype(np.int32)
maximum_path_c(path, value, t_x_max, t_y_max)
return torch.from_numpy(path).to(device=device, dtype=dtype)

File diff suppressed because it is too large.


@@ -0,0 +1,45 @@
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[x, y-1]
if x == 0:
if y == 0:
v_prev = 0.
else:
v_prev = max_neg_val
else:
v_prev = value[x-1, y-1]
value[x, y] = max(v_cur, v_prev) + value[x, y]
for y in range(t_y - 1, -1, -1):
path[index, y] = 1
if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
cdef int b = values.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)


@@ -0,0 +1,9 @@
from distutils.core import setup
from Cython.Build import cythonize
import numpy
setup(
name = 'monotonic_align',
ext_modules = cythonize("core.pyx"),
include_dirs=[numpy.get_include()]
)
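# Build note: the extension is typically compiled in place with the standard
# Cython/distutils invocation, e.g.
#     python setup.py build_ext --inplace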


@@ -0,0 +1,100 @@
import torch
from torch import nn
from torch.nn import functional as F
from mozilla_voice_tts.tts.layers.glow_tts.attention import MultiHeadAttention
class Transformer(nn.Module):
def __init__(self,
hidden_channels,
filter_channels,
num_heads,
num_layers,
kernel_size=1,
dropout_p=0.,
rel_attn_window_size=None,
input_length=None,
**kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.num_heads = num_heads
self.num_layers = num_layers
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.rel_attn_window_size = rel_attn_window_size
self.input_length = input_length
self.drop = nn.Dropout(dropout_p)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for _ in range(self.num_layers):
self.attn_layers.append(
MultiHeadAttention(hidden_channels,
hidden_channels,
num_heads,
window_size=rel_attn_window_size,
dropout_p=dropout_p,
input_length=input_length))
self.norm_layers_1.append(nn.GroupNorm(1, hidden_channels))
self.ffn_layers.append(
FFN(hidden_channels,
hidden_channels,
filter_channels,
kernel_size,
dropout_p=dropout_p))
self.norm_layers_2.append(nn.GroupNorm(1, hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.num_layers):
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class FFN(nn.Module):
def __init__(self,
in_channels,
out_channels,
filter_channels,
kernel_size,
dropout_p=0.,
activation=None):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.activation = activation
self.conv_1 = nn.Conv1d(in_channels,
filter_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(filter_channels,
out_channels,
kernel_size,
padding=kernel_size // 2)
self.drop = nn.Dropout(dropout_p)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
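# sigmoid-gated approximation of GELU: gelu(x) ≈ x * sigmoid(1.702 * x)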
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask


@@ -0,0 +1,178 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from mozilla_voice_tts.tts.layers.glow_tts.encoder import Encoder
from mozilla_voice_tts.tts.layers.glow_tts.decoder import Decoder
from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask
from mozilla_voice_tts.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
class GlowTts(nn.Module):
"""Glow TTS models from https://arxiv.org/abs/2005.11129"""
def __init__(self,
num_chars,
hidden_channels,
filter_channels,
filter_channels_dp,
out_channels,
kernel_size=3,
num_heads=2,
num_layers_enc=6,
dropout_p=0.1,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
num_block_layers=4,
dropout_p_dec=0.,
num_speakers=0,
c_in_channels=0,
num_splits=4,
num_sqz=1,
sigmoid_scale=False,
mean_only=False,
hidden_channels_enc=None,
hidden_channels_dec=None,
use_encoder_prenet=False):
super().__init__()
self.num_chars = num_chars
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.filter_channels_dp = filter_channels_dp
self.out_channels = out_channels
self.kernel_size = kernel_size
self.num_heads = num_heads
self.num_layers_enc = num_layers_enc
self.dropout_p = dropout_p
self.num_flow_blocks_dec = num_flow_blocks_dec
self.kernel_size_dec = kernel_size_dec
self.dilation_rate = dilation_rate
self.num_block_layers = num_block_layers
self.dropout_p_dec = dropout_p_dec
self.num_speakers = num_speakers
self.c_in_channels = c_in_channels
self.num_splits = num_splits
self.num_sqz = num_sqz
self.sigmoid_scale = sigmoid_scale
self.mean_only = mean_only
self.hidden_channels_enc = hidden_channels_enc
self.hidden_channels_dec = hidden_channels_dec
self.use_encoder_prenet = use_encoder_prenet
self.noise_scale = 0.66
self.length_scale = 1.0
self.encoder = Encoder(num_chars,
out_channels,
hidden_channels_enc or hidden_channels,
filter_channels,
filter_channels_dp,
num_heads,
num_layers_enc,
kernel_size,
dropout_p,
mean_only=mean_only,
use_prenet=True,
c_in_channels=c_in_channels)
self.decoder = Decoder(out_channels,
hidden_channels_dec or hidden_channels,
kernel_size_dec,
dilation_rate,
num_flow_blocks_dec,
num_block_layers,
dropout_p=dropout_p_dec,
num_splits=num_splits,
num_sqz=num_sqz,
sigmoid_scale=sigmoid_scale,
c_in_channels=c_in_channels)
if num_speakers > 1:
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
def compute_outputs(self, attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment
y_mean = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
# compute total duration with adjustment
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
return y_mean, y_log_scale, o_attn_dur
def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
"""
x: B x T
x_lengths: B
y: B x D x T
y_lengths: B
"""
y_max_length = y.size(2)
# norm speaker embeddings
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# format feature vectors and feature vector lengths
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
# create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1,2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1,2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
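# Note: the four terms expand the diagonal Gaussian log-likelihood
# log N(z; o_mean, exp(o_log_scale)) summed over channels, with
# o_scale = 1 / sigma^2:
#     logp1 = sum_d ( -0.5 * log(2*pi) - log(sigma) )
#     logp2 = sum_d ( -0.5 * z^2 / sigma^2 )
#     logp3 = sum_d (  mu * z / sigma^2 )
#     logp4 = sum_d ( -0.5 * mu^2 / sigma^2 )
# evaluated jointly for every (text position, decoder frame) pair.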
attn = maximum_path(logp,
attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
@torch.no_grad()
def inference(self, x, x_lengths, g=None):
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
# compute masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# compute attention mask
attn = generate_path(w_ceil.squeeze(1),
attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
self.noise_scale) * y_mask
# decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1)
return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
def preprocess(self, y, y_lengths, y_max_length, attn=None):
if y_max_length is not None:
y_max_length = (y_max_length // self.num_sqz) * self.num_sqz
y = y[:, :, :y_max_length]
if attn is not None:
attn = attn[:, :, :, :y_max_length]
y_lengths = (y_lengths // self.num_sqz) * self.num_sqz
return y, y_lengths, y_max_length, attn
def store_inverse(self):
self.decoder.store_inverse()
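A minimal, shape-level inference sketch for the model above (hyperparameter values are illustrative placeholders, not the project defaults, and the Cython alignment extension must be built first):

model = GlowTts(num_chars=120, hidden_channels=192, filter_channels=768,
filter_channels_dp=256, out_channels=80, mean_only=True)
model.eval()
text = torch.randint(0, 120, (1, 50))       # token ids [B, T]
text_lengths = torch.tensor([50])
mel, _, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur = model.inference(text, text_lengths)
# mel: [B, 80, T'] where T' is the sum of the predicted (ceiled) per-token durations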