GlowTTS zero-shot TTS Support

Edresson 2020-10-24 15:58:39 -03:00
parent b7f9ebd32b
commit 07345099ee
4 changed files with 91 additions and 46 deletions

View File

@@ -9,6 +9,7 @@ import time
import traceback
import torch
from random import randrange
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
@@ -36,8 +37,7 @@ from TTS.utils.training import (NoamLR, check_update,
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False):
def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
if is_val and not c.run_eval:
loader = None
else:
@@ -56,7 +56,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose)
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
@@ -86,10 +87,13 @@ def format_data(data):
avg_spec_length = torch.mean(mel_lengths.float())
if c.use_speaker_embedding:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
if c.use_external_speaker_embedding_file:
speaker_ids = data[8]
else:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
else:
speaker_ids = None
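
For intuition, a minimal standalone sketch (not part of the diff) of the two conditioning paths format_data now distinguishes; the [b, dim] shape and the placement of the external embeddings at data[8] are assumptions read off the branch above:

import torch

# internal-id mode: speaker names are looked up in speaker_mapping
speaker_mapping = {"p225": 0, "p226": 1}  # hypothetical name -> id mapping
speaker_names = ["p226", "p225"]
speaker_ids = torch.LongTensor([speaker_mapping[n] for n in speaker_names])  # [b]

# external-file mode: the collated batch is assumed to carry precomputed
# d-vectors at index 8, so no embedding lookup happens inside the model
external_embeddings = torch.randn(2, 256)  # assumed shape [b, embedding_dim]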
@@ -107,7 +111,7 @@ def format_data(data):
avg_text_length, avg_spec_length, attn_mask
def data_depended_init(model, ap):
def data_depended_init(model, ap, speaker_mapping=None):
"""Data depended initialization for activation normalization."""
if hasattr(model, 'module'):
for f in model.module.decoder.flows:
@@ -118,19 +122,19 @@ def data_depended_init(model, ap):
if getattr(f, "set_ddi", False):
f.set_ddi(True)
data_loader = setup_loader(ap, 1, is_val=False)
data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
model.train()
print(" > Data depended initialization ... ")
with torch.no_grad():
for _, data in enumerate(data_loader):
# format data
text_input, text_lengths, mel_input, mel_lengths, _,\
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
_, _, attn_mask = format_data(data)
# forward pass model
_ = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
break
if hasattr(model, 'module'):
@@ -145,9 +149,9 @@ def data_depended_init(model, ap):
def train(model, criterion, optimizer, scheduler,
ap, global_step, epoch, amp):
ap, global_step, epoch, amp, speaker_mapping=None):
data_loader = setup_loader(ap, 1, is_val=False,
verbose=(epoch == 0))
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
model.train()
epoch_time = 0
keep_avg = KeepAverage()
@@ -162,7 +166,7 @@ def train(model, criterion, optimizer, scheduler,
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, _,\
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
avg_text_length, avg_spec_length, attn_mask = format_data(data)
loader_time = time.time() - end_time
@@ -176,7 +180,7 @@ def train(model, criterion, optimizer, scheduler,
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -262,7 +266,7 @@ def train(model, criterion, optimizer, scheduler,
# Diagnostic visualizations
# direct pass on model for spec predictions
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1] if speaker_ids is not None else None)
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
const_spec = spec_pred[0].data.cpu().numpy()
@@ -298,8 +302,8 @@ def train(model, criterion, optimizer, scheduler,
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
data_loader = setup_loader(ap, 1, is_val=True)
def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping)
model.eval()
epoch_time = 0
keep_avg = KeepAverage()
@@ -309,12 +313,12 @@ def evaluate(model, criterion, ap, global_step, epoch):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, _,\
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
_, _, attn_mask = format_data(data)
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -356,9 +360,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
# Diagnostic visualizations
# direct pass on model for spec predictions
if hasattr(model, 'module'):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1])
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1] if speaker_ids is not None else None)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1] if speaker_ids is not None else None)
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
@@ -398,7 +402,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
test_audios = {}
test_figures = {}
print(" | > Synthesizing test sentences")
speaker_id = 0 if c.use_speaker_embedding else None
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]]['embedding']
speaker_id = None
else:
speaker_id = 0
speaker_embedding = None
else:
speaker_id = None
speaker_embedding = None
style_wav = c.get("style_wav_for_test")
for idx, test_sentence in enumerate(test_sentences):
try:
@@ -409,6 +423,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
use_cuda,
ap,
speaker_id=speaker_id,
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
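
A minimal sketch of the random test-speaker pick above, assuming the speakers.json layout written by the embedding-extraction notebooks (sample name mapped to a dict with "name" and "embedding"); since randrange(n) already yields 0..n-1, no -1 offset is needed:

from random import randrange

speaker_mapping = {  # hypothetical entries, 256-dim vectors assumed
    "p225_001.wav": {"name": "p225", "embedding": [0.01] * 256},
    "p226_004.wav": {"name": "p226", "embedding": [0.02] * 256},
}
keys = list(speaker_mapping.keys())
speaker_embedding = speaker_mapping[keys[randrange(len(keys))]]["embedding"]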
@@ -462,23 +477,42 @@ def main(args): # pylint: disable=redefined-outer-name
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now, you cannot " \
"introduce new speakers to " \
"a previously trained model."
else:
if c.use_external_speaker_embedding_file: # restoring a checkpoint that uses an external embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
if not speaker_mapping:
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
if not speaker_mapping:
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif not c.use_external_speaker_embedding_file: # restoring a checkpoint without an external embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
speaker_embedding_dim = None
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now, you cannot " \
"introduce new speakers to " \
"a previously trained model."
elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # starting a new training run with an external embedding file
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # external embeddings enabled but no file was given
raise RuntimeError("use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder")
else: # starting a new training run without an external embedding file
speaker_mapping = {name: i for i, name in enumerate(speakers)}
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
print("Training with {} speakers: {}".format(len(speakers),
", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None
speaker_mapping = None
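
Given that per-sample layout, the embedding dimensionality is inferred from the first entry of the mapping; a condensed sketch of the derivation, assuming a plain JSON file in place of the load_speaker_mapping helper:

import json

with open("speakers.json", "r") as f:  # hypothetical path
    speaker_mapping = json.load(f)
first_key = next(iter(speaker_mapping))
speaker_embedding_dim = len(speaker_mapping[first_key]["embedding"])  # e.g. 256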
# setup model
model = setup_model(num_chars, num_speakers, c)
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = GlowTTSLoss()
@@ -540,13 +574,13 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
global_step = args.restore_step
model = data_depended_init(model, ap)
model = data_depended_init(model, ap, speaker_mapping)
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
epoch, amp)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
epoch, amp, speaker_mapping)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval:

View File

@@ -68,13 +68,14 @@ class GlowTts(nn.Module):
self.use_encoder_prenet = use_encoder_prenet
self.noise_scale = 0.66
self.length_scale = 1.
self.external_speaker_embedding_dim = external_speaker_embedding_dim
# for multi-speaker models with c_in_channels unset, fall back to a default size or to the external embedding size
if num_speakers > 1:
if self.c_in_channels == 0 and not external_speaker_embedding_dim:
self.c_in_channels = 256
elif external_speaker_embedding_dim:
self.c_in_channels = external_speaker_embedding_dim
if self.c_in_channels == 0 and not self.external_speaker_embedding_dim:
self.c_in_channels = 512
elif self.external_speaker_embedding_dim:
self.c_in_channels = self.external_speaker_embedding_dim
self.encoder = Encoder(num_chars,
out_channels=out_channels,
@@ -88,7 +89,7 @@ class GlowTts(nn.Module):
dropout_p=dropout_p,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
c_in_channels=c_in_channels)
c_in_channels=self.c_in_channels)
self.decoder = Decoder(out_channels,
hidden_channels_dec or hidden_channels,
@@ -100,10 +101,10 @@ class GlowTts(nn.Module):
num_splits=num_splits,
num_sqz=num_sqz,
sigmoid_scale=sigmoid_scale,
c_in_channels=c_in_channels)
c_in_channels=self.c_in_channels)
if num_speakers > 1 and not external_speaker_embedding_dim:
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@staticmethod
@@ -130,7 +131,11 @@ class GlowTts(nn.Module):
y_max_length = y.size(2)
# norm speaker embeddings
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
@@ -165,8 +170,13 @@ class GlowTts(nn.Module):
@torch.no_grad()
def inference(self, x, x_lengths, g=None):
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
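
To make the two conditioning modes concrete, a standalone sketch of the g handling (batch size, table size, and embedding widths are illustrative; the normalize-then-unsqueeze pattern follows the lines above):

import torch
import torch.nn.functional as F

# internal table: g arrives as integer speaker ids and is embedded first
emb_g = torch.nn.Embedding(10, 512)          # num_speakers=10, c_in_channels=512
g_ids = torch.LongTensor([3, 7])             # [b]
g = F.normalize(emb_g(g_ids)).unsqueeze(-1)  # [b, 512, 1]

# external (zero-shot): g is already a float d-vector, only L2-normalized
g_ext = torch.randn(2, 256)                  # [b, external_speaker_embedding_dim]
g = F.normalize(g_ext).unsqueeze(-1)         # [b, 256, 1]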

View File

@@ -126,7 +126,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
mean_only=True,
hidden_channels_enc=192,
hidden_channels_dec=192,
use_encoder_prenet=True)
use_encoder_prenet=True,
external_speaker_embedding_dim=speaker_embedding_dim)
return model
def is_tacotron(c):

View File

@@ -59,7 +59,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
elif 'glow' in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
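
The explicit None test in the glow branch matters because speaker_id can legitimately be 0, which a plain truthiness check would discard in favor of speaker_embeddings. A tiny sketch of the rule (hypothetical helper, not in the repo):

def pick_g(speaker_id, speaker_embeddings):
    # speaker_id == 0 is a valid id, so compare against None explicitly
    return speaker_id if speaker_id is not None else speaker_embeddings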