mirror of https://github.com/coqui-ai/TTS.git

GlowTTS zero-shot TTS Support

parent b7f9ebd32b
commit 07345099ee
TTS/bin/train_glow_tts.py

@@ -9,6 +9,7 @@ import time
 import traceback

 import torch
+from random import randrange
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
@@ -36,8 +37,7 @@ from TTS.utils.training import (NoamLR, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)


-def setup_loader(ap, r, is_val=False, verbose=False):
-
+def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -56,7 +56,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             use_phonemes=c.use_phonemes,
             phoneme_language=c.phoneme_language,
             enable_eos_bos=c.enable_eos_bos_chars,
-            verbose=verbose)
+            verbose=verbose,
+            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
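
For orientation: the speaker_mapping threaded into MyDataset here is the parsed speakers.json produced by the speaker-encoder extraction notebooks. A minimal sketch of the structure this code path appears to assume (keys, names, and values below are illustrative placeholders, not taken from this commit):

# speakers.json sketch: one entry per clip, each carrying the speaker's name
# and a precomputed d-vector; the training code only reads 'embedding'.
speaker_mapping = {
    "clip_0001.wav": {"name": "speaker_a", "embedding": [0.01] * 256},
    "clip_0002.wav": {"name": "speaker_b", "embedding": [0.02] * 256},
}

# main() below measures the embedding size from the first entry:
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
assert speaker_embedding_dim == 256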
@@ -86,10 +87,13 @@ def format_data(data):
     avg_spec_length = torch.mean(mel_lengths.float())

     if c.use_speaker_embedding:
-        speaker_ids = [
-            speaker_mapping[speaker_name] for speaker_name in speaker_names
-        ]
-        speaker_ids = torch.LongTensor(speaker_ids)
+        if c.use_external_speaker_embedding_file:
+            speaker_ids = data[8]
+        else:
+            speaker_ids = [
+                speaker_mapping[speaker_name] for speaker_name in speaker_names
+            ]
+            speaker_ids = torch.LongTensor(speaker_ids)
     else:
         speaker_ids = None

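
The new branch in format_data() switches between precomputed d-vectors and integer speaker ids. A toy stand-in for the two branches (batch contents fabricated; the random tensor plays the role of data[8]):

import torch

# With an external embedding file the batch already carries one d-vector
# per item (data[8]); otherwise speaker names are mapped to integer ids
# for the nn.Embedding lookup.
speaker_mapping = {"speaker_a": 0, "speaker_b": 1}
speaker_names = ["speaker_b", "speaker_a"]

use_external_file = False  # flip to switch branches
if use_external_file:
    speaker_ids = torch.randn(2, 256)  # stand-in for data[8]
else:
    speaker_ids = torch.LongTensor(
        [speaker_mapping[name] for name in speaker_names])

print(speaker_ids)  # tensor([1, 0]) on the integer-id branch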
@@ -107,7 +111,7 @@ def format_data(data):
         avg_text_length, avg_spec_length, attn_mask


-def data_depended_init(model, ap):
+def data_depended_init(model, ap, speaker_mapping=None):
     """Data-dependent initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -118,19 +122,19 @@ def data_depended_init(model, ap):
         if getattr(f, "set_ddi", False):
             f.set_ddi(True)

-    data_loader = setup_loader(ap, 1, is_val=False)
+    data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
     model.train()
     print(" > Data-dependent initialization ... ")
     with torch.no_grad():
         for _, data in enumerate(data_loader):

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, _,\
+            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
                 _, _, attn_mask = format_data(data)

             # forward pass model
             _ = model.forward(
-                text_input, text_lengths, mel_input, mel_lengths, attn_mask)
+                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
             break

     if hasattr(model, 'module'):
@@ -145,9 +149,9 @@ def data_depended_init(model, ap):


 def train(model, criterion, optimizer, scheduler,
-          ap, global_step, epoch, amp):
+          ap, global_step, epoch, amp, speaker_mapping=None):
     data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0))
+                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -162,7 +166,7 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, _,\
+        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
             avg_text_length, avg_spec_length, attn_mask = format_data(data)

         loader_time = time.time() - end_time
@@ -176,7 +180,7 @@ def train(model, criterion, optimizer, scheduler,

         # forward pass model
         z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask)
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)

         # compute loss
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -262,7 +266,7 @@ def train(model, criterion, optimizer, scheduler,

             # Diagnostic visualizations
             # direct pass on model for spec predictions
-            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
+            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
             spec_pred = spec_pred.permute(0, 2, 1)
             gt_spec = mel_input.permute(0, 2, 1)
             const_spec = spec_pred[0].data.cpu().numpy()
@@ -298,8 +302,8 @@ def train(model, criterion, optimizer, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=True)
+def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
+    data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping)
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -309,12 +313,12 @@ def evaluate(model, criterion, ap, global_step, epoch):
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, _,\
+        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
             _, _, attn_mask = format_data(data)

         # forward pass model
         z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask)
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)

         # compute loss
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -356,9 +360,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
         # Diagnostic visualizations
         # direct pass on model for spec predictions
         if hasattr(model, 'module'):
-            spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1])
+            spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
         else:
-            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
+            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
         spec_pred = spec_pred.permute(0, 2, 1)
         gt_spec = mel_input.permute(0, 2, 1)

@@ -398,7 +402,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
-        speaker_id = 0 if c.use_speaker_embedding else None
+        if c.use_speaker_embedding:
+            if c.use_external_speaker_embedding_file:
+                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]]['embedding']
+                speaker_id = None
+            else:
+                speaker_id = 0
+                speaker_embedding = None
+        else:
+            speaker_id = None
+            speaker_embedding = None
+
         style_wav = c.get("style_wav_for_test")
         for idx, test_sentence in enumerate(test_sentences):
             try:
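
The test-sentence block now draws a random reference speaker whenever external embeddings are in use. The pick, isolated with placeholder entries (randrange(len(speaker_mapping)) so every key can be drawn):

from random import randrange

# Choose one clip key uniformly and use its stored d-vector for synthesis.
speaker_mapping = {
    "clip_0001.wav": {"name": "speaker_a", "embedding": [0.1, 0.2]},
    "clip_0002.wav": {"name": "speaker_b", "embedding": [0.3, 0.4]},
}
key = list(speaker_mapping.keys())[randrange(len(speaker_mapping))]
speaker_embedding = speaker_mapping[key]["embedding"]
print(key, speaker_embedding)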
@@ -409,6 +423,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                 use_cuda,
                 ap,
                 speaker_id=speaker_id,
+                speaker_embedding=speaker_embedding,
                 style_wav=style_wav,
                 truncated=False,
                 enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
@@ -462,23 +477,42 @@ def main(args):  # pylint: disable=redefined-outer-name
     if c.use_speaker_embedding:
         speakers = get_speakers(meta_data_train)
         if args.restore_path:
-            prev_out_path = os.path.dirname(args.restore_path)
-            speaker_mapping = load_speaker_mapping(prev_out_path)
-            assert all([speaker in speaker_mapping
-                        for speaker in speakers]), "As of now you, you cannot " \
-                                                   "introduce new speakers to " \
-                                                   "a previously trained model."
-        else:
+            if c.use_external_speaker_embedding_file:  # restore checkpoint and use external embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                if not speaker_mapping:
+                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
+                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+                    if not speaker_mapping:
+                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
+                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+            elif not c.use_external_speaker_embedding_file:  # restore checkpoint without an external embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                speaker_embedding_dim = None
+                assert all([speaker in speaker_mapping
+                            for speaker in speakers]), "As of now, you cannot " \
+                                                       "introduce new speakers to " \
+                                                       "a previously trained model."
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:  # start a new run using an external embedding file
+            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:  # external embeddings enabled but no file given
+            raise RuntimeError("use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder")
+        else:  # start a new run without an external embedding file
             speaker_mapping = {name: i for i, name in enumerate(speakers)}
+            speaker_embedding_dim = None
         save_speaker_mapping(OUT_PATH, speaker_mapping)
         num_speakers = len(speaker_mapping)
-        print("Training with {} speakers: {}".format(num_speakers,
+        print("Training with {} speakers: {}".format(len(speakers),
                                                      ", ".join(speakers)))
     else:
         num_speakers = 0
+        speaker_embedding_dim = None
+        speaker_mapping = None

     # setup model
-    model = setup_model(num_chars, num_speakers, c)
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
     optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
     criterion = GlowTTSLoss()
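
The block above handles four cases: restore with external embeddings, restore without, fresh run with an external embedding file, and fresh run with a plain learned speaker table. A hypothetical condensation as a pure function (load_fn/save_fn stand in for load_speaker_mapping/save_speaker_mapping; the config is a plain dict here, and the assert/fallback error handling of the diff is omitted):

import os

def resolve_speakers(restore_path, c, speakers, out_path, load_fn, save_fn):
    if restore_path:
        mapping = load_fn(os.path.dirname(restore_path))
        if c["use_external_speaker_embedding_file"]:
            if not mapping:  # fall back to the configured speakers.json
                mapping = load_fn(c["external_speaker_embedding_file"])
            dim = len(next(iter(mapping.values()))["embedding"])
        else:
            dim = None  # integer ids, learned nn.Embedding
    elif c["use_external_speaker_embedding_file"]:
        mapping = load_fn(c["external_speaker_embedding_file"])
        dim = len(next(iter(mapping.values()))["embedding"])
    else:  # fresh run with a learned speaker table
        mapping = {name: i for i, name in enumerate(speakers)}
        dim = None
    save_fn(out_path, mapping)
    return mapping, dim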
@@ -540,13 +574,13 @@ def main(args):  # pylint: disable=redefined-outer-name
     best_loss = float('inf')

     global_step = args.restore_step
-    model = data_depended_init(model, ap)
+    model = data_depended_init(model, ap, speaker_mapping)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  scheduler, ap, global_step,
-                                                 epoch, amp)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+                                                 epoch, amp, speaker_mapping)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
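
Taken together, the training-script changes just thread one speaker_mapping object through every stage. A wiring sketch with stub stages (stubs only; signatures mirror the diff):

def data_depended_init(model, ap, speaker_mapping=None):
    return model

def train(model, criterion, optimizer, scheduler, ap, global_step, epoch,
          amp, speaker_mapping=None):
    return {"avg_loss": 0.0}, global_step + 1

def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
    return {"avg_loss": 0.0}

model = ap = criterion = optimizer = scheduler = amp = None
speaker_mapping = {"speaker_a": 0}
model = data_depended_init(model, ap, speaker_mapping)
for epoch in range(1):
    stats, step = train(model, criterion, optimizer, scheduler, ap, 0, epoch,
                        amp, speaker_mapping)
    eval_stats = evaluate(model, criterion, ap, step, epoch,
                          speaker_mapping=speaker_mapping)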
TTS/tts/models/glow_tts.py

@@ -68,13 +68,14 @@ class GlowTts(nn.Module):
         self.use_encoder_prenet = use_encoder_prenet
         self.noise_scale = 0.66
         self.length_scale = 1.
+        self.external_speaker_embedding_dim = external_speaker_embedding_dim

         # if the model is multispeaker and c_in_channels is 0, pick a default width
         if num_speakers > 1:
-            if self.c_in_channels == 0 and not external_speaker_embedding_dim:
-                self.c_in_channels = 256
-            elif external_speaker_embedding_dim:
-                self.c_in_channels = external_speaker_embedding_dim
+            if self.c_in_channels == 0 and not self.external_speaker_embedding_dim:
+                self.c_in_channels = 512
+            elif self.external_speaker_embedding_dim:
+                self.c_in_channels = self.external_speaker_embedding_dim

         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
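
The conditioning-width rule above, isolated for clarity: an external d-vector dictates the width directly, while the learned speaker table gets a default of 512 after this commit.

# c_in_channels resolution as implemented above, as a standalone function.
def resolve_c_in_channels(num_speakers, c_in_channels, external_dim=None):
    if num_speakers > 1:
        if c_in_channels == 0 and not external_dim:
            return 512            # default width of the learned nn.Embedding
        if external_dim:
            return external_dim   # match the precomputed d-vector size
    return c_in_channels

assert resolve_c_in_channels(10, 0) == 512       # multispeaker, learned table
assert resolve_c_in_channels(10, 0, 256) == 256  # multispeaker, external d-vectors
assert resolve_c_in_channels(1, 0) == 0          # single speaker: unchanged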
@@ -88,7 +89,7 @@ class GlowTts(nn.Module):
                                dropout_p=dropout_p,
                                mean_only=mean_only,
                                use_prenet=use_encoder_prenet,
-                               c_in_channels=c_in_channels)
+                               c_in_channels=self.c_in_channels)

         self.decoder = Decoder(out_channels,
                                hidden_channels_dec or hidden_channels,
@@ -100,10 +101,10 @@ class GlowTts(nn.Module):
                                num_splits=num_splits,
                                num_sqz=num_sqz,
                                sigmoid_scale=sigmoid_scale,
-                               c_in_channels=c_in_channels)
+                               c_in_channels=self.c_in_channels)

         if num_speakers > 1 and not external_speaker_embedding_dim:
-            self.emb_g = nn.Embedding(num_speakers, c_in_channels)
+            self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

     @staticmethod
@@ -130,7 +131,11 @@ class GlowTts(nn.Module):
         y_max_length = y.size(2)
+        # norm speaker embeddings
         if g is not None:
-            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                               x_lengths,
@@ -165,8 +170,13 @@ class GlowTts(nn.Module):

     @torch.no_grad()
     def inference(self, x, x_lengths, g=None):

         if g is not None:
-            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                               x_lengths,
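
Both forward() and inference() now L2-normalize whichever conditioning vector they receive: a learned embedding looked up from integer ids, or an externally computed d-vector. A shape walk-through with toy sizes (all values random, sizes illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, num_speakers, dim = 2, 10, 256
emb_g = nn.Embedding(num_speakers, dim)

# internal path: integer ids -> learned embedding -> normalize -> [b, h, 1]
g_ids = torch.LongTensor([3, 7])
g_internal = F.normalize(emb_g(g_ids)).unsqueeze(-1)

# external path: precomputed d-vectors -> normalize -> [b, h, 1]
g_external = F.normalize(torch.randn(batch, dim)).unsqueeze(-1)

assert g_internal.shape == g_external.shape == (batch, dim, 1)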
TTS/tts/utils/generic_utils.py

@@ -126,7 +126,8 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                          mean_only=True,
                          hidden_channels_enc=192,
                          hidden_channels_dec=192,
-                         use_encoder_prenet=True)
+                         use_encoder_prenet=True,
+                         external_speaker_embedding_dim=speaker_embedding_dim)
     return model

 def is_tacotron(c):
TTS/tts/utils/synthesis.py

@@ -59,7 +59,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
             inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
     elif 'glow' in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
+        postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
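
The g-selection in run_model_torch, isolated: an explicit speaker id wins, including id 0, which is why the comparison is against None rather than truthiness.

# `is not None` keeps speaker id 0 valid; a bare truthiness test would
# silently fall through to the embedding (or to None).
def pick_g(speaker_id, speaker_embeddings):
    return speaker_id if speaker_id is not None else speaker_embeddings

assert pick_g(0, [0.1, 0.2]) == 0              # id 0 is a real speaker
assert pick_g(None, [0.1, 0.2]) == [0.1, 0.2]  # zero-shot: use the d-vector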