Merge pull request #545 from Edresson/dev

GlowTTS zeroshot TTS support
This commit is contained in:
Eren Gölge 2020-10-27 15:23:41 +01:00 committed by GitHub
commit f4b8170bd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 268 additions and 135 deletions

View File

@ -9,17 +9,17 @@ import time
import traceback
import torch
from random import randrange
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
reduce_tensor)
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
save_speaker_mapping)
from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -36,8 +36,7 @@ from TTS.utils.training import (NoamLR, check_update,
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False):
def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
if is_val and not c.run_eval:
loader = None
else:
@ -48,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
min_seq_len=c.min_seq_len,
@ -56,7 +56,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose)
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(
dataset,
@ -86,10 +87,13 @@ def format_data(data):
avg_spec_length = torch.mean(mel_lengths.float())
if c.use_speaker_embedding:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
if c.use_external_speaker_embedding_file:
speaker_ids = data[8]
else:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
else:
speaker_ids = None
@ -107,7 +111,7 @@ def format_data(data):
avg_text_length, avg_spec_length, attn_mask
def data_depended_init(model, ap):
def data_depended_init(model, ap, speaker_mapping=None):
"""Data depended initialization for activation normalization."""
if hasattr(model, 'module'):
for f in model.module.decoder.flows:
@ -118,19 +122,19 @@ def data_depended_init(model, ap):
if getattr(f, "set_ddi", False):
f.set_ddi(True)
data_loader = setup_loader(ap, 1, is_val=False)
data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
model.train()
print(" > Data depended initialization ... ")
with torch.no_grad():
for _, data in enumerate(data_loader):
# format data
text_input, text_lengths, mel_input, mel_lengths, _,\
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
_, _, attn_mask = format_data(data)
# forward pass model
_ = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
break
if hasattr(model, 'module'):
@ -145,9 +149,9 @@ def data_depended_init(model, ap):
def train(model, criterion, optimizer, scheduler,
ap, global_step, epoch, amp):
ap, global_step, epoch, amp, speaker_mapping=None):
data_loader = setup_loader(ap, 1, is_val=False,
verbose=(epoch == 0))
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
model.train()
epoch_time = 0
keep_avg = KeepAverage()
@ -162,7 +166,7 @@ def train(model, criterion, optimizer, scheduler,
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, _,\
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
avg_text_length, avg_spec_length, attn_mask = format_data(data)
loader_time = time.time() - end_time
@ -176,7 +180,7 @@ def train(model, criterion, optimizer, scheduler,
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@ -262,7 +266,7 @@ def train(model, criterion, optimizer, scheduler,
# Diagnostic visualizations
# direct pass on model for spec predictions
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
const_spec = spec_pred[0].data.cpu().numpy()
@ -298,8 +302,8 @@ def train(model, criterion, optimizer, scheduler,
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
data_loader = setup_loader(ap, 1, is_val=True)
def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping)
model.eval()
epoch_time = 0
keep_avg = KeepAverage()
@ -309,12 +313,12 @@ def evaluate(model, criterion, ap, global_step, epoch):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, _,\
text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
_, _, attn_mask = format_data(data)
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@ -356,9 +360,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
# Diagnostic visualizations
# direct pass on model for spec predictions
if hasattr(model, 'module'):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1])
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
@ -398,7 +402,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
test_audios = {}
test_figures = {}
print(" | > Synthesizing test sentences")
speaker_id = 0 if c.use_speaker_embedding else None
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
speaker_id = None
else:
speaker_id = 0
speaker_embedding = None
else:
speaker_id = None
speaker_embedding = None
style_wav = c.get("style_wav_for_test")
for idx, test_sentence in enumerate(test_sentences):
try:
@ -409,6 +423,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
use_cuda,
ap,
speaker_id=speaker_id,
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
@ -459,26 +474,10 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now you, you cannot " \
"introduce new speakers to " \
"a previously trained model."
else:
speaker_mapping = {name: i for i, name in enumerate(speakers)}
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
", ".join(speakers)))
else:
num_speakers = 0
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
# setup model
model = setup_model(num_chars, num_speakers, c)
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = GlowTTSLoss()
@ -540,13 +539,13 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = float('inf')
global_step = args.restore_step
model = data_depended_init(model, ap)
model = data_depended_init(model, ap, speaker_mapping)
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
epoch, amp)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
epoch, amp, speaker_mapping)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval:
@ -602,6 +601,7 @@ if __name__ == '__main__':
# setup output paths and read configs
c = load_config(args.config_path)
# check_config(c)
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level:

View File

@ -22,8 +22,7 @@ from TTS.tts.utils.distribute import (DistributedSampler,
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
save_speaker_mapping)
from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -52,6 +51,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
min_seq_len=c.min_seq_len,
@ -502,42 +502,7 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
if not speaker_mapping:
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
if not speaker_mapping:
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
speaker_embedding_dim = None
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now you, you cannot " \
"introduce new speakers to " \
"a previously trained model."
elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file
raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
else: # if start new train and don't use External Embedding file
speaker_mapping = {name: i for i, name in enumerate(speakers)}
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None
speaker_mapping = None
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

View File

@ -51,6 +51,8 @@
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
// },
"add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.
// DISTRIBUTED TRAINING
"distributed":{
"backend": "nccl",

View File

@ -51,6 +51,8 @@
// "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
// },
"add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.
// DISTRIBUTED TRAINING
"distributed":{
"backend": "nccl",

View File

@ -17,6 +17,7 @@ class MyDataset(Dataset):
ap,
meta_data,
tp=None,
add_blank=False,
batch_group_size=0,
min_seq_len=0,
max_seq_len=float("inf"),
@ -55,6 +56,7 @@ class MyDataset(Dataset):
self.max_seq_len = max_seq_len
self.ap = ap
self.tp = tp
self.add_blank = add_blank
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
@ -88,7 +90,7 @@ class MyDataset(Dataset):
phonemes = phoneme_to_sequence(text, [self.cleaners],
language=self.phoneme_language,
enable_eos_bos=False,
tp=self.tp)
tp=self.tp, add_blank=self.add_blank)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
@ -127,7 +129,7 @@ class MyDataset(Dataset):
text = self._load_or_generate_phoneme_sequence(wav_file, text)
else:
text = np.asarray(text_to_sequence(text, [self.cleaners],
tp=self.tp),
tp=self.tp, add_blank=self.add_blank),
dtype=np.int32)
assert text.size > 0, self.items[idx][1]

View File

@ -37,7 +37,8 @@ class GlowTts(nn.Module):
hidden_channels_enc=None,
hidden_channels_dec=None,
use_encoder_prenet=False,
encoder_type="transformer"):
encoder_type="transformer",
external_speaker_embedding_dim=None):
super().__init__()
self.num_chars = num_chars
@ -67,6 +68,14 @@ class GlowTts(nn.Module):
self.use_encoder_prenet = use_encoder_prenet
self.noise_scale = 0.66
self.length_scale = 1.
self.external_speaker_embedding_dim = external_speaker_embedding_dim
# if is a multispeaker and c_in_channels is 0, set to 256
if num_speakers > 1:
if self.c_in_channels == 0 and not self.external_speaker_embedding_dim:
self.c_in_channels = 512
elif self.external_speaker_embedding_dim:
self.c_in_channels = self.external_speaker_embedding_dim
self.encoder = Encoder(num_chars,
out_channels=out_channels,
@ -80,7 +89,7 @@ class GlowTts(nn.Module):
dropout_p=dropout_p,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
c_in_channels=c_in_channels)
c_in_channels=self.c_in_channels)
self.decoder = Decoder(out_channels,
hidden_channels_dec or hidden_channels,
@ -92,10 +101,10 @@ class GlowTts(nn.Module):
num_splits=num_splits,
num_sqz=num_sqz,
sigmoid_scale=sigmoid_scale,
c_in_channels=c_in_channels)
c_in_channels=self.c_in_channels)
if num_speakers > 1:
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
if num_speakers > 1 and not external_speaker_embedding_dim:
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@staticmethod
@ -122,7 +131,11 @@ class GlowTts(nn.Module):
y_max_length = y.size(2)
# norm speaker embeddings
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
@ -157,8 +170,13 @@ class GlowTts(nn.Module):
@torch.no_grad()
def inference(self, x, x_lengths, g=None):
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,

View File

@ -126,13 +126,15 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
mean_only=True,
hidden_channels_enc=192,
hidden_channels_dec=192,
use_encoder_prenet=True)
use_encoder_prenet=True,
external_speaker_embedding_dim=speaker_embedding_dim)
return model
def is_tacotron(c):
return False if c['model'] == 'glow_tts' else True
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
check_argument('run_name', c, restricted=True, val_type=str)
check_argument('run_description', c, val_type=str)
@ -195,27 +197,30 @@ def check_config_tts(c):
check_argument('seq_len_norm', c, restricted=True, val_type=bool)
# tacotron prenet
check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
check_argument('prenet_dropout', c, restricted=True, val_type=bool)
check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
# attention
check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
check_argument('attention_heads', c, restricted=True, val_type=int)
check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
check_argument('windowing', c, restricted=True, val_type=bool)
check_argument('use_forward_attn', c, restricted=True, val_type=bool)
check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
check_argument('transition_agent', c, restricted=True, val_type=bool)
check_argument('transition_agent', c, restricted=True, val_type=bool)
check_argument('location_attn', c, restricted=True, val_type=bool)
check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
# stopnet
check_argument('stopnet', c, restricted=True, val_type=bool)
check_argument('separate_stopnet', c, restricted=True, val_type=bool)
check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
# GlowTTS parameters
check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
# tensorboard
check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
@ -240,15 +245,16 @@ def check_config_tts(c):
# multi-speaker and gst
check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool)
check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str)
check_argument('use_gst', c, restricted=True, val_type=bool)
check_argument('gst', c, restricted=True, val_type=dict)
check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict])
check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000)
check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool)
check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10)
check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000)
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
if c['use_gst']:
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
# datasets - checking only the first entry
check_argument('datasets', c, restricted=True, val_type=list)

View File

@ -30,3 +30,44 @@ def get_speakers(items):
"""Returns a sorted, unique list of speakers in a given dataset."""
speakers = {e[2] for e in items}
return sorted(speakers)
def parse_speakers(c, args, meta_data_train, OUT_PATH):
""" Returns number of speakers, speaker embedding shape and speaker mapping"""
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
if not speaker_mapping:
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
if not speaker_mapping:
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
speaker_embedding_dim = None
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now you, you cannot " \
"introduce new speakers to " \
"a previously trained model."
elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file
raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
else: # if start new train and don't use External Embedding file
speaker_mapping = {name: i for i, name in enumerate(speakers)}
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(len(speakers),
", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None
speaker_mapping = None
return num_speakers, speaker_embedding_dim, speaker_mapping

View File

@ -14,10 +14,13 @@ def text_to_seqvec(text, CONFIG):
seq = np.asarray(
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None),
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False),
dtype=np.int32)
else:
seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
seq = np.asarray(
text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32)
return seq
@ -59,7 +62,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
elif 'glow' in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id else speaker_embeddings)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None

View File

@ -16,6 +16,8 @@ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
@ -57,6 +59,10 @@ def text2phone(text, language):
return ph
def intersperse(sequence, token):
result = [token] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
def pad_with_eos_bos(phoneme_sequence, tp=None):
# pylint: disable=global-statement
@ -69,10 +75,9 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None):
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False):
# pylint: disable=global-statement
global _phonemes_to_id
global _phonemes_to_id, _phonemes
if tp:
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
@ -88,13 +93,17 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=
# Append EOS char
if enable_eos_bos:
sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence
def sequence_to_phoneme(sequence, tp=None):
def sequence_to_phoneme(sequence, tp=None, add_blank=False):
# pylint: disable=global-statement
'''Converts a sequence of IDs back to a string'''
global _id_to_phonemes
global _id_to_phonemes, _phonemes
if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = ''
if tp:
_, _phonemes = make_symbols(**tp)
@ -107,7 +116,7 @@ def sequence_to_phoneme(sequence, tp=None):
return result.replace('}{', ' ')
def text_to_sequence(text, cleaner_names, tp=None):
def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
@ -121,7 +130,7 @@ def text_to_sequence(text, cleaner_names, tp=None):
List of integers corresponding to the symbols in the text
'''
# pylint: disable=global-statement
global _symbol_to_id
global _symbol_to_id, _symbols
if tp:
_symbols, _ = make_symbols(**tp)
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
@ -137,13 +146,19 @@ def text_to_sequence(text, cleaner_names, tp=None):
_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
if add_blank:
sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols)
return sequence
def sequence_to_text(sequence, tp=None):
def sequence_to_text(sequence, tp=None, add_blank=False):
'''Converts a sequence of IDs back to a string'''
# pylint: disable=global-statement
global _id_to_symbol
global _id_to_symbol, _symbols
if add_blank:
sequence = list(filter(lambda x: x != len(_symbols), sequence))
if tp:
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}

View File

@ -11,6 +11,7 @@ from TTS.utils.io import load_config
conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))
def test_phoneme_to_sequence():
text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
text_cleaner = ["phoneme_cleaners"]
lang = "en-us"
@ -20,7 +21,7 @@ def test_phoneme_to_sequence():
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
assert text_hat == text_hat_with_params == gt
# multiple punctuations
text = "Be a voice, not an! echo?"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
@ -87,6 +88,84 @@ def test_phoneme_to_sequence():
print(len(sequence))
assert text_hat == text_hat_with_params == gt
def test_phoneme_to_sequence_with_blank_token():
text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
text_cleaner = ["phoneme_cleaners"]
lang = "en-us"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
assert text_hat == text_hat_with_params == gt
# multiple punctuations
text = "Be a voice, not an! echo?"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
# not ending with punctuation
text = "Be a voice, not an! echo"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
# original
text = "Be a voice, not an echo!"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
# extra space after the sentence
text = "Be a voice, not an! echo. "
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
# extra space after the sentence
text = "Be a voice, not an! echo. "
sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
# padding char
text = "_Be a _voice, not an! echo_"
sequence = phoneme_to_sequence(text, text_cleaner, lang)
text_hat = sequence_to_phoneme(sequence)
_ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
print(text_hat)
print(len(sequence))
assert text_hat == text_hat_with_params == gt
def test_text2phone():
text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"