Merge pull request #545 from Edresson/dev

GlowTTS zeroshot TTS support
Eren Gölge 2020-10-27 15:23:41 +01:00 committed by GitHub
commit f4b8170bd1
11 changed files with 268 additions and 135 deletions

View File

@@ -9,17 +9,17 @@ import time
 import traceback
 import torch
+from random import randrange
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
                                       reduce_tensor)
-from TTS.tts.utils.generic_utils import setup_model
+from TTS.tts.utils.generic_utils import setup_model, check_config_tts
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
-                                    save_speaker_mapping)
+from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -36,8 +36,7 @@ from TTS.utils.training import (NoamLR, check_update,
 use_cuda, num_gpus = setup_torch_training_env(True, False)

-def setup_loader(ap, r, is_val=False, verbose=False):
+def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
     if is_val and not c.run_eval:
         loader = None
     else:
@@ -48,6 +47,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
             tp=c.characters if 'characters' in c.keys() else None,
+            add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
             batch_group_size=0 if is_val else c.batch_group_size *
             c.batch_size,
             min_seq_len=c.min_seq_len,
@@ -56,7 +56,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             use_phonemes=c.use_phonemes,
             phoneme_language=c.phoneme_language,
             enable_eos_bos=c.enable_eos_bos_chars,
-            verbose=verbose)
+            verbose=verbose,
+            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
@@ -86,10 +87,13 @@ def format_data(data):
     avg_spec_length = torch.mean(mel_lengths.float())

     if c.use_speaker_embedding:
-        speaker_ids = [
-            speaker_mapping[speaker_name] for speaker_name in speaker_names
-        ]
-        speaker_ids = torch.LongTensor(speaker_ids)
+        if c.use_external_speaker_embedding_file:
+            speaker_ids = data[8]
+        else:
+            speaker_ids = [
+                speaker_mapping[speaker_name] for speaker_name in speaker_names
+            ]
+            speaker_ids = torch.LongTensor(speaker_ids)
     else:
         speaker_ids = None
@@ -107,7 +111,7 @@ def format_data(data):
            avg_text_length, avg_spec_length, attn_mask

-def data_depended_init(model, ap):
+def data_depended_init(model, ap, speaker_mapping=None):
     """Data depended initialization for activation normalization."""
     if hasattr(model, 'module'):
         for f in model.module.decoder.flows:
@@ -118,19 +122,19 @@ def data_depended_init(model, ap):
         if getattr(f, "set_ddi", False):
             f.set_ddi(True)

-    data_loader = setup_loader(ap, 1, is_val=False)
+    data_loader = setup_loader(ap, 1, is_val=False, speaker_mapping=speaker_mapping)
     model.train()
     print(" > Data depended initialization ... ")
     with torch.no_grad():
         for _, data in enumerate(data_loader):

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, _,\
+            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
                 _, _, attn_mask = format_data(data)

             # forward pass model
             _ = model.forward(
-                text_input, text_lengths, mel_input, mel_lengths, attn_mask)
+                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
             break

     if hasattr(model, 'module'):
@@ -145,9 +149,9 @@ def data_depended_init(model, ap):

 def train(model, criterion, optimizer, scheduler,
-          ap, global_step, epoch, amp):
+          ap, global_step, epoch, amp, speaker_mapping=None):
     data_loader = setup_loader(ap, 1, is_val=False,
-                               verbose=(epoch == 0))
+                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -162,7 +166,7 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, _,\
+        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
             avg_text_length, avg_spec_length, attn_mask = format_data(data)

         loader_time = time.time() - end_time
@@ -176,7 +180,7 @@ def train(model, criterion, optimizer, scheduler,
         # forward pass model
         z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask)
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)

         # compute loss
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -262,7 +266,7 @@ def train(model, criterion, optimizer, scheduler,
             # Diagnostic visualizations
             # direct pass on model for spec predictions
-            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
+            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
             spec_pred = spec_pred.permute(0, 2, 1)
             gt_spec = mel_input.permute(0, 2, 1)
             const_spec = spec_pred[0].data.cpu().numpy()
@@ -298,8 +302,8 @@ def train(model, criterion, optimizer, scheduler,

 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, 1, is_val=True)
+def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
+    data_loader = setup_loader(ap, 1, is_val=True, speaker_mapping=speaker_mapping)
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -309,12 +313,12 @@ def evaluate(model, criterion, ap, global_step, epoch):
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, _,\
+        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
             _, _, attn_mask = format_data(data)

         # forward pass model
         z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-            text_input, text_lengths, mel_input, mel_lengths, attn_mask)
+            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)

         # compute loss
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@@ -356,9 +360,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
         # Diagnostic visualizations
         # direct pass on model for spec predictions
         if hasattr(model, 'module'):
-            spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1])
+            spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
         else:
-            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1])
+            spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
         spec_pred = spec_pred.permute(0, 2, 1)
         gt_spec = mel_input.permute(0, 2, 1)
@@ -398,7 +402,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
-        speaker_id = 0 if c.use_speaker_embedding else None
+        if c.use_speaker_embedding:
+            if c.use_external_speaker_embedding_file:
+                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
+                speaker_id = None
+            else:
+                speaker_id = 0
+                speaker_embedding = None
+        else:
+            speaker_id = None
+            speaker_embedding = None
         style_wav = c.get("style_wav_for_test")
         for idx, test_sentence in enumerate(test_sentences):
             try:
@@ -409,6 +423,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     use_cuda,
                     ap,
                     speaker_id=speaker_id,
+                    speaker_embedding=speaker_embedding,
                     style_wav=style_wav,
                     truncated=False,
                     enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
@@ -459,26 +474,10 @@ def main(args):  # pylint: disable=redefined-outer-name
         meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]

     # parse speakers
-    if c.use_speaker_embedding:
-        speakers = get_speakers(meta_data_train)
-        if args.restore_path:
-            prev_out_path = os.path.dirname(args.restore_path)
-            speaker_mapping = load_speaker_mapping(prev_out_path)
-            assert all([speaker in speaker_mapping
-                        for speaker in speakers]), "As of now you, you cannot " \
-                                                   "introduce new speakers to " \
-                                                   "a previously trained model."
-        else:
-            speaker_mapping = {name: i for i, name in enumerate(speakers)}
-            save_speaker_mapping(OUT_PATH, speaker_mapping)
-        num_speakers = len(speaker_mapping)
-        print("Training with {} speakers: {}".format(num_speakers,
-                                                     ", ".join(speakers)))
-    else:
-        num_speakers = 0
+    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)

     # setup model
-    model = setup_model(num_chars, num_speakers, c)
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
     optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
     criterion = GlowTTSLoss()
@@ -540,13 +539,13 @@ def main(args):  # pylint: disable=redefined-outer-name
         best_loss = float('inf')

     global_step = args.restore_step
-    model = data_depended_init(model, ap)
+    model = data_depended_init(model, ap, speaker_mapping)
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  scheduler, ap, global_step,
-                                                 epoch, amp)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
+                                                 epoch, amp, speaker_mapping)
+        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_loss']
         if c.run_eval:
@@ -602,6 +601,7 @@ if __name__ == '__main__':
     # setup output paths and read configs
     c = load_config(args.config_path)
     # check_config(c)
+    check_config_tts(c)
     _ = os.path.dirname(os.path.realpath(__file__))

     if c.apex_amp_level:
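
Note: with use_external_speaker_embedding_file enabled, the training script above no longer builds an id lookup table; it reads precomputed speaker vectors from a speakers.json-style mapping and, at batch time, takes them straight from the ninth collate field (data[8]). A minimal sketch of the file handling, assuming only what the diff relies on — a local speakers.json whose entries carry an 'embedding' list (the path and any other fields are hypothetical):

    import json

    # hypothetical path; the trainer gets it from CONFIG.external_speaker_embedding_file
    with open("speakers.json", "r") as f:
        speaker_mapping = json.load(f)

    # embedding size is inferred from the first entry, exactly as the new code does
    first_key = list(speaker_mapping.keys())[0]
    speaker_embedding_dim = len(speaker_mapping[first_key]["embedding"])

    # during training, format_data() skips the nn.Embedding lookup and uses the
    # precomputed vectors that the dataset returns as data[8]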

View File

@@ -22,8 +22,7 @@ from TTS.tts.utils.distribute import (DistributedSampler,
 from TTS.tts.utils.generic_utils import setup_model, check_config_tts
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
-                                    save_speaker_mapping)
+from TTS.tts.utils.speakers import parse_speakers, load_speaker_mapping
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -52,6 +51,7 @@ def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
             tp=c.characters if 'characters' in c.keys() else None,
+            add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
             batch_group_size=0 if is_val else c.batch_group_size *
             c.batch_size,
             min_seq_len=c.min_seq_len,
@@ -502,42 +502,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]

     # parse speakers
-    if c.use_speaker_embedding:
-        speakers = get_speakers(meta_data_train)
-        if args.restore_path:
-            if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
-                prev_out_path = os.path.dirname(args.restore_path)
-                speaker_mapping = load_speaker_mapping(prev_out_path)
-                if not speaker_mapping:
-                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
-                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
-                    if not speaker_mapping:
-                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
-                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
-            elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
-                prev_out_path = os.path.dirname(args.restore_path)
-                speaker_mapping = load_speaker_mapping(prev_out_path)
-                speaker_embedding_dim = None
-            assert all([speaker in speaker_mapping
-                        for speaker in speakers]), "As of now you, you cannot " \
-                                                   "introduce new speakers to " \
-                                                   "a previously trained model."
-        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
-            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
-            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
-        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file
-            raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
-        else: # if start new train and don't use External Embedding file
-            speaker_mapping = {name: i for i, name in enumerate(speakers)}
-            speaker_embedding_dim = None
-        save_speaker_mapping(OUT_PATH, speaker_mapping)
-        num_speakers = len(speaker_mapping)
-        print("Training with {} speakers: {}".format(num_speakers,
-                                                     ", ".join(speakers)))
-    else:
-        num_speakers = 0
-        speaker_embedding_dim = None
-        speaker_mapping = None
+    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)

     model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

View File

@@ -51,6 +51,8 @@
     //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
     // },
+    "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.

     // DISTRIBUTED TRAINING
     "distributed":{
         "backend": "nccl",

View File

@@ -51,6 +51,8 @@
     //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
     // },
+    "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.

     // DISTRIBUTED TRAINING
     "distributed":{
         "backend": "nccl",

View File

@@ -17,6 +17,7 @@ class MyDataset(Dataset):
                  ap,
                  meta_data,
                  tp=None,
+                 add_blank=False,
                  batch_group_size=0,
                  min_seq_len=0,
                  max_seq_len=float("inf"),
@@ -55,6 +56,7 @@ class MyDataset(Dataset):
         self.max_seq_len = max_seq_len
         self.ap = ap
         self.tp = tp
+        self.add_blank = add_blank
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
@@ -88,7 +90,7 @@ class MyDataset(Dataset):
             phonemes = phoneme_to_sequence(text, [self.cleaners],
                                            language=self.phoneme_language,
                                            enable_eos_bos=False,
-                                           tp=self.tp)
+                                           tp=self.tp, add_blank=self.add_blank)
             phonemes = np.asarray(phonemes, dtype=np.int32)
             np.save(cache_path, phonemes)
         return phonemes
@@ -127,7 +129,7 @@ class MyDataset(Dataset):
             text = self._load_or_generate_phoneme_sequence(wav_file, text)
         else:
             text = np.asarray(text_to_sequence(text, [self.cleaners],
-                                               tp=self.tp),
+                                               tp=self.tp, add_blank=self.add_blank),
                               dtype=np.int32)

         assert text.size > 0, self.items[idx][1]

View File

@@ -37,7 +37,8 @@ class GlowTts(nn.Module):
                  hidden_channels_enc=None,
                  hidden_channels_dec=None,
                  use_encoder_prenet=False,
-                 encoder_type="transformer"):
+                 encoder_type="transformer",
+                 external_speaker_embedding_dim=None):

         super().__init__()
         self.num_chars = num_chars
@@ -67,6 +68,14 @@ class GlowTts(nn.Module):
         self.use_encoder_prenet = use_encoder_prenet
         self.noise_scale = 0.66
         self.length_scale = 1.
+        self.external_speaker_embedding_dim = external_speaker_embedding_dim
+
+        # if is a multispeaker and c_in_channels is 0, set to 256
+        if num_speakers > 1:
+            if self.c_in_channels == 0 and not self.external_speaker_embedding_dim:
+                self.c_in_channels = 512
+            elif self.external_speaker_embedding_dim:
+                self.c_in_channels = self.external_speaker_embedding_dim

         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
@@ -80,7 +89,7 @@ class GlowTts(nn.Module):
                                dropout_p=dropout_p,
                                mean_only=mean_only,
                                use_prenet=use_encoder_prenet,
-                               c_in_channels=c_in_channels)
+                               c_in_channels=self.c_in_channels)

         self.decoder = Decoder(out_channels,
                                hidden_channels_dec or hidden_channels,
@@ -92,10 +101,10 @@ class GlowTts(nn.Module):
                                num_splits=num_splits,
                                num_sqz=num_sqz,
                                sigmoid_scale=sigmoid_scale,
-                               c_in_channels=c_in_channels)
+                               c_in_channels=self.c_in_channels)

-        if num_speakers > 1:
-            self.emb_g = nn.Embedding(num_speakers, c_in_channels)
+        if num_speakers > 1 and not external_speaker_embedding_dim:
+            self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

     @staticmethod
@@ -122,7 +131,11 @@ class GlowTts(nn.Module):
         y_max_length = y.size(2)
         # norm speaker embeddings
         if g is not None:
-            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                               x_lengths,
@@ -157,8 +170,13 @@ class GlowTts(nn.Module):
     @torch.no_grad()
     def inference(self, x, x_lengths, g=None):
         if g is not None:
-            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
+            if self.external_speaker_embedding_dim:
+                g = F.normalize(g).unsqueeze(-1)
+            else:
+                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                               x_lengths,
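
Note: the two conditioning paths above differ only in where g comes from; both end up as a normalized [B, C, 1] tensor whose channel size matches the c_in_channels handed to the encoder and decoder. A standalone shape sketch (not library code; the 256-d external embedding is an assumption):

    import torch
    import torch.nn.functional as F

    # internal table: integer speaker ids -> [B, c_in_channels] -> [B, c_in_channels, 1]
    emb_g = torch.nn.Embedding(10, 512)
    g_ids = torch.LongTensor([0, 3])
    g_internal = F.normalize(emb_g(g_ids)).unsqueeze(-1)    # torch.Size([2, 512, 1])

    # external d-vectors: already [B, external_speaker_embedding_dim]
    g_external = torch.randn(2, 256)                        # hypothetical 256-d embeddings
    g_external = F.normalize(g_external).unsqueeze(-1)      # torch.Size([2, 256, 1])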

View File

@@ -126,13 +126,15 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                        mean_only=True,
                        hidden_channels_enc=192,
                        hidden_channels_dec=192,
-                       use_encoder_prenet=True)
+                       use_encoder_prenet=True,
+                       external_speaker_embedding_dim=speaker_embedding_dim)
     return model

+def is_tacotron(c):
+    return False if c['model'] == 'glow_tts' else True

 def check_config_tts(c):
-    check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
+    check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
     check_argument('run_name', c, restricted=True, val_type=str)
     check_argument('run_description', c, val_type=str)
@@ -195,27 +197,30 @@ def check_config_tts(c):
     check_argument('seq_len_norm', c, restricted=True, val_type=bool)

     # tacotron prenet
-    check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
-    check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
-    check_argument('prenet_dropout', c, restricted=True, val_type=bool)
+    check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
+    check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
+    check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)

     # attention
-    check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
-    check_argument('attention_heads', c, restricted=True, val_type=int)
-    check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
-    check_argument('windowing', c, restricted=True, val_type=bool)
-    check_argument('use_forward_attn', c, restricted=True, val_type=bool)
-    check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
-    check_argument('transition_agent', c, restricted=True, val_type=bool)
-    check_argument('transition_agent', c, restricted=True, val_type=bool)
-    check_argument('location_attn', c, restricted=True, val_type=bool)
-    check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
-    check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
+    check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
+    check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
+    check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
+    check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
     check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)

     # stopnet
-    check_argument('stopnet', c, restricted=True, val_type=bool)
-    check_argument('separate_stopnet', c, restricted=True, val_type=bool)
+    check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
+    check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
+
+    # GlowTTS parameters
+    check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)

     # tensorboard
     check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
@@ -240,15 +245,16 @@ def check_config_tts(c):
     # multi-speaker and gst
     check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
-    check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool)
-    check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str)
-    check_argument('use_gst', c, restricted=True, val_type=bool)
-    check_argument('gst', c, restricted=True, val_type=dict)
-    check_argument('gst_style_input', c['gst'], restricted=True, val_type=[str, dict])
-    check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=0, max_val=1000)
-    check_argument('gst_use_speaker_embedding', c['gst'], restricted=True, val_type=bool)
-    check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=2, max_val=10)
-    check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1, max_val=1000)
+    check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
+    check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
+    check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
+    if c['use_gst']:
+        check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
+        check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
+        check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
+        check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
+        check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
+        check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)

     # datasets - checking only the first entry
     check_argument('datasets', c, restricted=True, val_type=list)
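
Note: is_tacotron() is what relaxes the Tacotron-only checks when the config selects glow_tts. A self-contained sketch of the gating (check_argument here is a simplified stand-in, not the library function):

    def is_tacotron(c):
        return False if c['model'] == 'glow_tts' else True

    def check_argument(name, c, restricted=False, **kwargs):
        # stand-in: a restricted key must be present in the config
        if restricted:
            assert name in c, " [!] '{}' missing in config.json".format(name)

    c = {'model': 'glow_tts'}                                    # hypothetical minimal config
    check_argument('prenet_type', c, restricted=is_tacotron(c))  # passes: optional for glow_tts
    check_argument('model', c, restricted=True)                  # still required for every model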

View File

@@ -30,3 +30,44 @@ def get_speakers(items):
     """Returns a sorted, unique list of speakers in a given dataset."""
     speakers = {e[2] for e in items}
     return sorted(speakers)
+
+
+def parse_speakers(c, args, meta_data_train, OUT_PATH):
+    """ Returns number of speakers, speaker embedding shape and speaker mapping"""
+    if c.use_speaker_embedding:
+        speakers = get_speakers(meta_data_train)
+        if args.restore_path:
+            if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                if not speaker_mapping:
+                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
+                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+                    if not speaker_mapping:
+                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
+                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+            elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                speaker_embedding_dim = None
+            assert all([speaker in speaker_mapping
+                        for speaker in speakers]), "As of now you, you cannot " \
+                                                   "introduce new speakers to " \
+                                                   "a previously trained model."
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
+            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file
+            raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
+        else: # if start new train and don't use External Embedding file
+            speaker_mapping = {name: i for i, name in enumerate(speakers)}
+            speaker_embedding_dim = None
+        save_speaker_mapping(OUT_PATH, speaker_mapping)
+        num_speakers = len(speaker_mapping)
+        print("Training with {} speakers: {}".format(len(speakers),
+                                                     ", ".join(speakers)))
+    else:
+        num_speakers = 0
+        speaker_embedding_dim = None
+        speaker_mapping = None
+    return num_speakers, speaker_embedding_dim, speaker_mapping
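
Note: both training scripts only need the triple that parse_speakers() returns. What comes back in the common cases (values illustrative, grounded on the branches above):

    # use_speaker_embedding = False
    #     -> (0, None, None)
    # use_speaker_embedding = True, no external embedding file
    #     -> (len(speakers), None, {'speaker_a': 0, 'speaker_b': 1, ...})
    # use_speaker_embedding = True, use_external_speaker_embedding_file = True
    #     -> (len(speaker_mapping), embedding_dim, mapping whose entries carry an 'embedding' list)
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(
        c, args, meta_data_train, OUT_PATH)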

View File

@@ -14,10 +14,13 @@ def text_to_seqvec(text, CONFIG):
         seq = np.asarray(
             phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
                                 CONFIG.enable_eos_bos_chars,
-                                tp=CONFIG.characters if 'characters' in CONFIG.keys() else None),
+                                tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
+                                add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False),
             dtype=np.int32)
     else:
-        seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
+        seq = np.asarray(
+            text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
+                             add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), dtype=np.int32)
     return seq
@@ -59,7 +62,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
             inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
     elif 'glow' in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
+        postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id else speaker_embeddings)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
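
Note: at inference time the same external vector can be passed straight through; a hedged sketch of zero-shot synthesis, reusing only the keyword arguments that the evaluate() loop in this diff passes to synthesis() (model, C, ap, and use_cuda are assumed to be an already-loaded GlowTTS checkpoint, its config, an AudioProcessor, and a bool):

    import json
    from TTS.tts.utils.synthesis import synthesis

    with open("speakers.json") as f:                     # hypothetical embedding file
        speaker_mapping = json.load(f)
    speaker_embedding = next(iter(speaker_mapping.values()))["embedding"]

    outputs = synthesis(model,
                        "This voice was never seen during training.",
                        C, use_cuda, ap,
                        speaker_id=None,                      # skip the embedding table
                        speaker_embedding=speaker_embedding,  # zero-shot conditioning
                        style_wav=None,
                        truncated=False,
                        enable_eos_bos_chars=C.enable_eos_bos_chars)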

View File

@@ -16,6 +16,8 @@ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
 _phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
 _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}

+_symbols = symbols
+_phonemes = phonemes

 # Regular expression matching text enclosed in curly braces:
 _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
@@ -57,6 +59,10 @@ def text2phone(text, language):
     return ph

+def intersperse(sequence, token):
+    result = [token] * (len(sequence) * 2 + 1)
+    result[1::2] = sequence
+    return result

 def pad_with_eos_bos(phoneme_sequence, tp=None):
     # pylint: disable=global-statement
@@ -69,10 +75,9 @@ def pad_with_eos_bos(phoneme_sequence, tp=None):
     return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]

-def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None):
+def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False):
     # pylint: disable=global-statement
-    global _phonemes_to_id
+    global _phonemes_to_id, _phonemes
     if tp:
         _, _phonemes = make_symbols(**tp)
         _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
@@ -88,13 +93,17 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=
     # Append EOS char
     if enable_eos_bos:
         sequence = pad_with_eos_bos(sequence, tp=tp)
+    if add_blank:
+        sequence = intersperse(sequence, len(_phonemes))  # add a blank token (new), whose id number is len(_phonemes)
     return sequence

-def sequence_to_phoneme(sequence, tp=None):
+def sequence_to_phoneme(sequence, tp=None, add_blank=False):
     # pylint: disable=global-statement
     '''Converts a sequence of IDs back to a string'''
-    global _id_to_phonemes
+    global _id_to_phonemes, _phonemes
+    if add_blank:
+        sequence = list(filter(lambda x: x != len(_phonemes), sequence))
     result = ''
     if tp:
         _, _phonemes = make_symbols(**tp)
@@ -107,7 +116,7 @@ def sequence_to_phoneme(sequence, tp=None):
     return result.replace('}{', ' ')

-def text_to_sequence(text, cleaner_names, tp=None):
+def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
@@ -121,7 +130,7 @@ def text_to_sequence(text, cleaner_names, tp=None):
         List of integers corresponding to the symbols in the text
     '''
     # pylint: disable=global-statement
-    global _symbol_to_id
+    global _symbol_to_id, _symbols
     if tp:
         _symbols, _ = make_symbols(**tp)
         _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
@@ -137,13 +146,19 @@ def text_to_sequence(text, cleaner_names, tp=None):
                         _clean_text(m.group(1), cleaner_names))
         sequence += _arpabet_to_sequence(m.group(2))
         text = m.group(3)
+
+    if add_blank:
+        sequence = intersperse(sequence, len(_symbols))  # add a blank token (new), whose id number is len(_symbols)
     return sequence

-def sequence_to_text(sequence, tp=None):
+def sequence_to_text(sequence, tp=None, add_blank=False):
     '''Converts a sequence of IDs back to a string'''
     # pylint: disable=global-statement
-    global _id_to_symbol
+    global _id_to_symbol, _symbols
+
+    if add_blank:
+        sequence = list(filter(lambda x: x != len(_symbols), sequence))
+
     if tp:
         _symbols, _ = make_symbols(**tp)
         _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
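
Note: decoding reverses the add_blank step by filtering out the blank id, which is why sequence_to_phoneme() and sequence_to_text() gained the same flag. A round-trip sketch with toy ids (not the real symbol table):

    def intersperse(sequence, token):
        result = [token] * (len(sequence) * 2 + 1)
        result[1::2] = sequence
        return result

    blank_id = 129                                      # hypothetical: len(_symbols)
    encoded = intersperse([5, 9, 2], blank_id)          # [129, 5, 129, 9, 129, 2, 129]
    decoded = list(filter(lambda x: x != blank_id, encoded))
    assert decoded == [5, 9, 2]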

View File

@@ -11,6 +11,7 @@ from TTS.utils.io import load_config
 conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json'))

 def test_phoneme_to_sequence():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
     text_cleaner = ["phoneme_cleaners"]
     lang = "en-us"
@@ -20,7 +21,7 @@ def test_phoneme_to_sequence():
     text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
     gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
     assert text_hat == text_hat_with_params == gt

     # multiple punctuations
     text = "Be a voice, not an! echo?"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
@@ -87,6 +88,84 @@ def test_phoneme_to_sequence():
     print(len(sequence))
     assert text_hat == text_hat_with_params == gt
+def test_phoneme_to_sequence_with_blank_token():
+    text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
+    text_cleaner = ["phoneme_cleaners"]
+    lang = "en-us"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "ɹiːsənt ɹɪːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
+    assert text_hat == text_hat_with_params == gt
+
+    # multiple punctuations
+    text = "Be a voice, not an! echo?"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == text_hat_with_params == gt
+
+    # not ending with punctuation
+    text = "Be a voice, not an! echo"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == text_hat_with_params == gt
+
+    # original
+    text = "Be a voice, not an echo!"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == text_hat_with_params == gt
+
+    # extra space after the sentence
+    text = "Be a voice, not an! echo. "
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == text_hat_with_params == gt
+
+    # extra space after the sentence
+    text = "Be a voice, not an! echo. "
+    sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == text_hat_with_params == gt
+
+    # padding char
+    text = "_Be a _voice, not an! echo_"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True)
+    text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == text_hat_with_params == gt

 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
     gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"