mirror of https://github.com/coqui-ai/TTS.git
Add external speaker embedding per sample instead of nn.Embedding
This commit is contained in:
parent e265810e8c
commit 8a1c113df6
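The commit lets the data loader attach a pre-computed speaker embedding to every sample, keyed by the sample's wav file name, instead of looking the speaker up in an nn.Embedding table. A minimal sketch of the mapping file this assumes (layout inferred from how the collate function and the train script index it; the "name" field and the tiny 3-dimensional vectors are illustrative only):

    import json

    # toy speakers.json: one entry per wav file; real embeddings are longer, e.g. 256-d
    speaker_mapping = {
        "p225_001.wav": {"name": "p225", "embedding": [0.01, -0.23, 0.45]},
        "p226_002.wav": {"name": "p226", "embedding": [0.12, 0.08, -0.31]},
    }

    with open("speakers.json", "w") as f:
        json.dump(speaker_mapping, f, indent=2)

    # the train script reads the embedding size from the first entry
    embedding_dim = len(next(iter(speaker_mapping.values()))["embedding"])
    print(embedding_dim)  # 3 in this toy example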
@@ -49,7 +49,7 @@ from mozilla_voice_tts.utils.training import (NoamLR, adam_weight_decay,
 use_cuda, num_gpus = setup_torch_training_env(True, False)


-def setup_loader(ap, r, is_val=False, verbose=False):
+def setup_loader(ap, r, is_val=False, verbose=False, speaker_mapping=None):
     if is_val and not c.run_eval:
         loader = None
     else:

@@ -68,7 +68,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             use_phonemes=c.use_phonemes,
             phoneme_language=c.phoneme_language,
             enable_eos_bos=c.enable_eos_bos_chars,
-            verbose=verbose)
+            verbose=verbose,
+            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
         sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,

@@ -82,9 +83,8 @@ def setup_loader(ap, r, is_val=False, verbose=False):
             pin_memory=False)
     return loader


-def format_data(data):
-    if c.use_speaker_embedding:
+def format_data(data, speaker_mapping=None):
+    if speaker_mapping is None and c.use_speaker_embedding and not c.use_external_speaker_embedding_file:
         speaker_mapping = load_speaker_mapping(OUT_PATH)

     # setup input data
@@ -99,13 +99,20 @@ def format_data(data):
     avg_spec_length = torch.mean(mel_lengths.float())

     if c.use_speaker_embedding:
-        speaker_ids = [
-            speaker_mapping[speaker_name] for speaker_name in speaker_names
-        ]
-        speaker_ids = torch.LongTensor(speaker_ids)
+        if c.use_external_speaker_embedding_file:
+            speaker_embeddings = data[8]
+            speaker_ids = None
+        else:
+            speaker_ids = [
+                speaker_mapping[speaker_name] for speaker_name in speaker_names
+            ]
+            speaker_ids = torch.LongTensor(speaker_ids)
+            speaker_embeddings = None
     else:
+        speaker_embeddings = None
         speaker_ids = None

     # set stop targets view, we predict a single stop token per iteration.
     stop_targets = stop_targets.view(text_input.shape[0],
                                      stop_targets.size(1) // c.r, -1)
@@ -122,13 +129,16 @@ def format_data(data):
         stop_targets = stop_targets.cuda(non_blocking=True)
         if speaker_ids is not None:
             speaker_ids = speaker_ids.cuda(non_blocking=True)
-    return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length
+        if speaker_embeddings is not None:
+            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
+
+    return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length


 def train(model, criterion, optimizer, optimizer_st, scheduler,
-          ap, global_step, epoch, amp):
+          ap, global_step, epoch, amp, speaker_mapping=None):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
-                               verbose=(epoch == 0))
+                               verbose=(epoch == 0), speaker_mapping=speaker_mapping)
     model.train()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -143,7 +153,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
+        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, avg_text_length, avg_spec_length = format_data(data, speaker_mapping)
         loader_time = time.time() - end_time

         global_step += 1

@@ -158,10 +168,10 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # forward pass model
         if c.bidirectional_decoder or c.double_decoder_consistency:
             decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
+                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
         else:
             decoder_output, postnet_output, alignments, stop_tokens = model(
-                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
+                text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
             decoder_backward_output = None
             alignments_backward = None
@@ -312,8 +322,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,


 @torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
+def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=None):
+    data_loader = setup_loader(ap, model.decoder.r, is_val=True, speaker_mapping=speaker_mapping)
     model.eval()
     epoch_time = 0
     keep_avg = KeepAverage()
@@ -323,16 +333,16 @@ def evaluate(model, criterion, ap, global_step, epoch):
             start_time = time.time()

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data, speaker_mapping)
             assert mel_input.shape[1] % model.decoder.r == 0

             # forward pass model
             if c.bidirectional_decoder or c.double_decoder_consistency:
                 decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
             else:
                 decoder_output, postnet_output, alignments, stop_tokens = model(
-                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
                 decoder_backward_output = None
                 alignments_backward = None
@@ -494,22 +504,41 @@ def main(args): # pylint: disable=redefined-outer-name
     if c.use_speaker_embedding:
         speakers = get_speakers(meta_data_train)
         if args.restore_path:
-            prev_out_path = os.path.dirname(args.restore_path)
-            speaker_mapping = load_speaker_mapping(prev_out_path)
-            assert all([speaker in speaker_mapping
-                        for speaker in speakers]), "As of now you, you cannot " \
-                                                   "introduce new speakers to " \
-                                                   "a previously trained model."
-        else:
+            if c.use_external_speaker_embedding_file:  # if restore checkpoint and use External Embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                if not speaker_mapping:
+                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
+                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+                    if not speaker_mapping:
+                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
+                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+            elif not c.use_external_speaker_embedding_file:  # if restore checkpoint and don't use External Embedding file
+                prev_out_path = os.path.dirname(args.restore_path)
+                speaker_mapping = load_speaker_mapping(prev_out_path)
+                speaker_embedding_dim = None
+                assert all([speaker in speaker_mapping
+                            for speaker in speakers]), "As of now you, you cannot " \
+                                                       "introduce new speakers to " \
+                                                       "a previously trained model."
+        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:  # if start new train using External Embedding file
+            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
+            print(speaker_mapping)
+            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
+        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file:  # if start new train using External Embedding file and don't pass external embedding file
+            raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
+        else:  # if start new train and don't use External Embedding file
             speaker_mapping = {name: i for i, name in enumerate(speakers)}
+            speaker_embedding_dim = None
         save_speaker_mapping(OUT_PATH, speaker_mapping)
         num_speakers = len(speaker_mapping)
         print("Training with {} speakers: {}".format(num_speakers,
                                                      ", ".join(speakers)))
     else:
         num_speakers = 0
+        speaker_embedding_dim = None

-    model = setup_model(num_chars, num_speakers, c)
+    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)

     params = set_weight_decay(model, c.wd)
     optimizer = RAdam(params, lr=c.lr, weight_decay=0)
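When training starts fresh from an external embedding file, the embedding size is read off the first entry of the loaded mapping. A self-contained sketch of that lookup, with an exception instance for the empty-mapping case (Python 3 cannot raise a bare string, so the guard below uses RuntimeError; the helper name and message are illustrative):

    def get_speaker_embedding_dim(speaker_mapping):
        # infer the per-sample embedding size from any entry of speakers.json
        if not speaker_mapping:
            raise RuntimeError("external speaker embedding file is empty or missing")
        first_key = next(iter(speaker_mapping))
        return len(speaker_mapping[first_key]["embedding"])

    print(get_speaker_embedding_dim({"a.wav": {"embedding": [0.0] * 256}}))  # 256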
@@ -530,6 +559,8 @@ def main(args): # pylint: disable=redefined-outer-name

     # setup criterion
     criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
+    for name, _ in model.named_parameters():
+        print(name)

     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
@@ -592,7 +623,7 @@ def main(args): # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                  optimizer_st, scheduler, ap,
-                                                 global_step, epoch, amp)
+                                                 global_step, epoch, amp, speaker_mapping)
         eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
@@ -1,5 +1,5 @@
 {
-    "model": "Tacotron",
+    "model": "Tacotron2",
     "run_name": "ljspeech-ddc-bn",
     "run_description": "tacotron2 with ddc and batch-normalization",

@@ -114,7 +114,7 @@
     "tb_model_param_stats": false,     // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

     // DATA LOADING
-    "text_cleaner": "portuguese_cleaners",
+    "text_cleaner": "phoneme_cleaners",
     "enable_eos_bos_chars": false,     // enable/disable beginning of sentence and end of sentence chars.
     "num_loader_workers": 4,           // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "num_val_loader_workers": 4,       // number of evaluation data loader processes.

@@ -131,7 +131,9 @@
     "phoneme_language": "en-us",       // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

     // MULTI-SPEAKER and GST
     "use_speaker_embedding": true,     // use speaker embedding to enable multi-speaker learning.
+    "use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs/1806.04558
+    "external_speaker_embedding_file": "../../speakers-vctk-en.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs/1806.04558
     "use_gst": true,                   // use global style tokens
     "gst": {                           // gst parameter if gst is enabled
         "gst_style_input": null,       // Condition the style input either on a
@@ -24,6 +24,7 @@ class MyDataset(Dataset):
                  phoneme_cache_path=None,
                  phoneme_language="en-us",
                  enable_eos_bos=False,
+                 speaker_mapping=None,
                  verbose=False):
         """
         Args:

@@ -58,6 +59,7 @@ class MyDataset(Dataset):
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
         self.enable_eos_bos = enable_eos_bos
+        self.speaker_mapping = speaker_mapping
         self.verbose = verbose
         if use_phonemes and not os.path.isdir(phoneme_cache_path):
             os.makedirs(phoneme_cache_path, exist_ok=True)

@@ -127,7 +129,8 @@ class MyDataset(Dataset):
             'text': text,
             'wav': wav,
             'item_idx': self.items[idx][1],
-            'speaker_name': speaker_name
+            'speaker_name': speaker_name,
+            'wav_file_name': os.path.basename(wav_file)
         }
         return sample

@@ -191,9 +194,15 @@ class MyDataset(Dataset):
                 batch[idx]['item_idx'] for idx in ids_sorted_decreasing
             ]
             text = [batch[idx]['text'] for idx in ids_sorted_decreasing]

             speaker_name = [batch[idx]['speaker_name']
                             for idx in ids_sorted_decreasing]
+            # get speaker embeddings
+            if self.speaker_mapping is not None:
+                wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing]
+                speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names]
+            else:
+                speaker_embedding = None
             # compute features
             mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
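A toy, standalone version of the lookup the collate function now performs, assuming the mapping layout sketched earlier (it mirrors only the new speaker-embedding branch, not the full collate_fn):

    import torch

    speaker_mapping = {
        "a.wav": {"embedding": [0.1, 0.2, 0.3]},
        "b.wav": {"embedding": [0.4, 0.5, 0.6]},
    }
    batch = [{"wav_file_name": "b.wav"}, {"wav_file_name": "a.wav"}]

    wav_files_names = [sample["wav_file_name"] for sample in batch]
    speaker_embedding = [speaker_mapping[w]["embedding"] for w in wav_files_names]
    speaker_embedding = torch.FloatTensor(speaker_embedding)
    print(speaker_embedding.shape)  # torch.Size([2, 3]); returned as the last collate output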
@@ -224,6 +233,9 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

+            if speaker_embedding is not None:
+                speaker_embedding = torch.FloatTensor(speaker_embedding)
+
             # compute linear spectrogram
             if self.compute_linear_spec:
                 linear = [self.ap.spectrogram(w).astype('float32') for w in wav]

@@ -234,7 +246,7 @@ class MyDataset(Dataset):
             else:
                 linear = None
             return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
-                stop_targets, item_idxs
+                stop_targets, item_idxs, speaker_embedding

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))
@@ -291,7 +291,7 @@ class Decoder(nn.Module):
     def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn,
                  trans_agent, forward_attn_mask, location_attn, attn_K,
-                 separate_stopnet, speaker_embedding_dim):
+                 separate_stopnet):
         super(Decoder, self).__init__()
         self.r_init = r
         self.r = r

@@ -462,15 +462,12 @@ class Decoder(nn.Module):
             t += 1
         return self._parse_outputs(outputs, attentions, stop_tokens)

-    def inference(self, inputs, speaker_embeddings=None):
+    def inference(self, inputs):
         """
         Args:
             inputs: encoder outputs.
-            speaker_embeddings: speaker vectors.

         Shapes:
-            - inputs: (B, T, D_out_enc)
-            - speaker_embeddings: (B, D_embed)
+            - inputs: batch x time x encoder_out_dim
         """
         outputs = []
         attentions = []

@@ -483,8 +480,6 @@ class Decoder(nn.Module):
             if t > 0:
                 new_memory = outputs[-1]
                 self._update_memory_input(new_memory)
-            if speaker_embeddings is not None:
-                self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1)
             output, stop_token, attention = self.decode(inputs, None)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [output]
@@ -147,8 +147,7 @@ class Decoder(nn.Module):
     #pylint: disable=attribute-defined-outside-init
     def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 forward_attn_mask, location_attn, attn_K, separate_stopnet,
-                 speaker_embedding_dim):
+                 forward_attn_mask, location_attn, attn_K, separate_stopnet):
         super(Decoder, self).__init__()
         self.frame_channels = frame_channels
         self.r_init = r

@@ -335,16 +334,14 @@ class Decoder(nn.Module):
                 outputs, stop_tokens, alignments)
         return outputs, alignments, stop_tokens

-    def inference(self, inputs, speaker_embeddings=None):
+    def inference(self, inputs):
         r"""Decoder inference without teacher forcing and use
         Stopnet to stop decoder.
         Args:
             inputs: Encoder outputs.
-            speaker_embeddings: speaker embedding vectors.

         Shapes:
             - inputs: (B, T, D_out_enc)
-            - speaker_embeddings: (B, D_embed)
             - outputs: (B, T_mel, D_mel)
             - alignments: (B, T_in, T_out)
             - stop_tokens: (B, T_out)

@@ -358,8 +355,6 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments, t = [], [], [], 0
         while True:
             memory = self.prenet(memory)
-            if speaker_embeddings is not None:
-                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             decoder_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [decoder_output.squeeze(1)]
@@ -27,6 +27,7 @@ class Tacotron(TacotronAbstract):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
+                 speaker_embedding_dim=None,
                  gst=False,
                  gst_embedding_dim=256,
                  gst_num_heads=4,
@@ -40,39 +41,46 @@ class Tacotron(TacotronAbstract):
                          location_attn, attn_K, separate_stopnet,
                          bidirectional_decoder, double_decoder_consistency,
                          ddc_r, gst)


        # init layer dims
        decoder_in_features = 256
        encoder_in_features = 256
-        speaker_embedding_dim = 256
-        proj_speaker_dim = 80 if num_speakers > 1 else 0

+        if speaker_embedding_dim is None:
+            # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim
+            self.embeddings_per_sample = False
+            speaker_embedding_dim = 256
+        else:
+            # if speaker_embedding_dim is not None we need use speaker embedding per sample
+            self.embeddings_per_sample = True
+
+        # speaker and gst embeddings is concat in decoder input
        if num_speakers > 1:
            decoder_in_features = decoder_in_features + speaker_embedding_dim  # add speaker embedding dim
        if self.gst:
            decoder_in_features = decoder_in_features + gst_embedding_dim  # add gst embedding dim

-        # base model layers
+        # embedding layer
        self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
+
+        # speaker embedding layers
+        if num_speakers > 1:
+            if not self.embeddings_per_sample:
+                self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
+                self.speaker_embedding.weight.data.normal_(0, 0.3)
+
+        # base model layers
        self.embedding.weight.data.normal_(0, 0.3)
        self.encoder = Encoder(encoder_in_features)
        self.decoder = Decoder(decoder_in_features, decoder_output_dim, r,
                               memory_size, attn_type, attn_win, attn_norm,
                               prenet_type, prenet_dropout, forward_attn,
                               trans_agent, forward_attn_mask, location_attn,
-                               attn_K, separate_stopnet, proj_speaker_dim)
+                               attn_K, separate_stopnet)
        self.postnet = PostCBHG(decoder_output_dim)
        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
                                     postnet_output_dim)
-        # speaker embedding layers
-        if num_speakers > 1:
-            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
-            self.speaker_embedding.weight.data.normal_(0, 0.3)
-            self.speaker_project_mel = nn.Sequential(
-                nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh())
-            self.speaker_embeddings = None
-            self.speaker_embeddings_projected = None
        # global style token layers
        if self.gst:
            self.gst_layer = GST(num_mel=80,
@@ -88,10 +96,9 @@ class Tacotron(TacotronAbstract):
                decoder_in_features, decoder_output_dim, ddc_r, memory_size,
                attn_type, attn_win, attn_norm, prenet_type, prenet_dropout,
                forward_attn, trans_agent, forward_attn_mask, location_attn,
-                attn_K, separate_stopnet, proj_speaker_dim)
+                attn_K, separate_stopnet)

-    def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None):
+    def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
        """
        Shapes:
            - characters: B x T_in
@@ -99,24 +106,27 @@ class Tacotron(TacotronAbstract):
            - mel_specs: B x T_out x D
            - speaker_ids: B x 1
        """
-        self._init_states()
        input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
        # B x T_in x embed_dim
        inputs = self.embedding(characters)
-        # B x speaker_embed_dim
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
        # B x T_in x encoder_in_features
        encoder_outputs = self.encoder(inputs)
        # sequence masking
        encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)

        # global style token
        if self.gst:
            # B x gst_dim
            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+        # speaker embedding
        if self.num_speakers > 1:
-            encoder_outputs = self._concat_speaker_embedding(
-                encoder_outputs, self.speaker_embeddings)
+            if not self.embeddings_per_sample:
+                # B x 1 x speaker_embed_dim
+                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
+            else:
+                # B x 1 x speaker_embed_dim
+                speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
+            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
        # decoder_outputs: B x decoder_in_features x T_out
        # alignments: B x T_in x encoder_in_features
        # stop_tokens: B x T_in
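Whichever branch runs, the model ends up with one embedding row per sample that is unsqueezed to B x 1 x D and concatenated onto the encoder outputs along the feature axis. A standalone sketch of that step (it assumes the repo's _concat_speaker_embedding helper expands the vector across the time axis, as below):

    import torch

    B, T, D_enc, D_spk = 2, 5, 256, 128
    encoder_outputs = torch.randn(B, T, D_enc)
    speaker_embeddings = torch.randn(B, D_spk)       # one external embedding per sample

    speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)   # B x 1 x D_spk
    speaker_embeddings = speaker_embeddings.expand(B, T, D_spk)   # repeated over time
    encoder_outputs = torch.cat([encoder_outputs, speaker_embeddings], dim=-1)
    print(encoder_outputs.shape)  # torch.Size([2, 5, 384]), i.e. decoder_in_features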
@@ -143,19 +153,22 @@ class Tacotron(TacotronAbstract):
        return decoder_outputs, postnet_outputs, alignments, stop_tokens

    @torch.no_grad()
-    def inference(self, characters, speaker_ids=None, style_mel=None):
+    def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None):
        inputs = self.embedding(characters)
-        self._init_states()
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
        encoder_outputs = self.encoder(inputs)
-        if self.gst and style_mel is not None:
+        if self.gst:
+            # B x gst_dim
            encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
        if self.num_speakers > 1:
-            encoder_outputs = self._concat_speaker_embedding(
-                encoder_outputs, self.speaker_embeddings)
+            if not self.embeddings_per_sample:
+                # B x 1 x speaker_embed_dim
+                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
+            else:
+                # B x 1 x speaker_embed_dim
+                speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
+            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
        decoder_outputs, alignments, stop_tokens = self.decoder.inference(
-            encoder_outputs, self.speaker_embeddings_projected)
+            encoder_outputs)
        postnet_outputs = self.postnet(decoder_outputs)
        postnet_outputs = self.last_linear(postnet_outputs)
        decoder_outputs = decoder_outputs.transpose(1, 2)
@@ -27,6 +27,7 @@ class Tacotron2(TacotronAbstract):
                 bidirectional_decoder=False,
                 double_decoder_consistency=False,
                 ddc_r=None,
+                 speaker_embedding_dim=None,
                 gst=False,
                 gst_embedding_dim=512,
                 gst_num_heads=4,
@@ -41,25 +42,38 @@ class Tacotron2(TacotronAbstract):
                         ddc_r, gst)

        # init layer dims
-        speaker_embedding_dim = 512 if num_speakers > 1 else 0
-        gst_embedding_dim = gst_embedding_dim if self.gst else 0
-        decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
-        encoder_in_features = 512 if num_speakers > 1 else 512
-        proj_speaker_dim = 80 if num_speakers > 1 else 0
+        decoder_in_features = 512
+        encoder_in_features = 512
+
+        if speaker_embedding_dim is None:
+            # if speaker_embedding_dim is None we need use the nn.Embedding, with default speaker_embedding_dim
+            self.embeddings_per_sample = False
+            speaker_embedding_dim = 512
+        else:
+            # if speaker_embedding_dim is not None we need use speaker embedding per sample
+            self.embeddings_per_sample = True
+
+        # speaker and gst embeddings is concat in decoder input
+        if num_speakers > 1:
+            decoder_in_features = decoder_in_features + speaker_embedding_dim  # add speaker embedding dim
+        if self.gst:
+            decoder_in_features = decoder_in_features + gst_embedding_dim  # add gst embedding dim
+
        # embedding layer
        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)

        # speaker embedding layer
        if num_speakers > 1:
-            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
-            self.speaker_embedding.weight.data.normal_(0, 0.3)
+            if not self.embeddings_per_sample:
+                self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
+                self.speaker_embedding.weight.data.normal_(0, 0.3)

+        # base model layers
        self.encoder = Encoder(encoder_in_features)
        self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
                               attn_norm, prenet_type, prenet_dropout,
                               forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, attn_K, separate_stopnet, proj_speaker_dim)
+                               location_attn, attn_K, separate_stopnet)
        self.postnet = Postnet(self.postnet_output_dim)

        # global style token layers
@@ -77,7 +91,7 @@ class Tacotron2(TacotronAbstract):
                decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
                attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
                trans_agent, forward_attn_mask, location_attn, attn_K,
-                separate_stopnet, proj_speaker_dim)
+                separate_stopnet)

    @staticmethod
    def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):

@@ -85,7 +99,7 @@ class Tacotron2(TacotronAbstract):
        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
        return mel_outputs, mel_outputs_postnet, alignments

-    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
+    def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
        # compute mask for padding
        # B x T_in_max (boolean)
        input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@@ -99,8 +113,13 @@ class Tacotron2(TacotronAbstract):
            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)

        if self.num_speakers > 1:
-            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
+            if not self.embeddings_per_sample:
+                # B x 1 x speaker_embed_dim
+                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
+            else:
+                # B x 1 x speaker_embed_dim
+                speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
+            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)

        encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
@@ -128,23 +147,18 @@ class Tacotron2(TacotronAbstract):
        return decoder_outputs, postnet_outputs, alignments, stop_tokens

    @torch.no_grad()
-    def inference(self, text, speaker_ids=None, style_mel=None):
+    def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)

+        if self.gst:
+            # B x gst_dim
+            encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
+
        if self.num_speakers > 1:
-            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
-            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
-            if hasattr(self, 'gst'):
-                # B x gst_dim
-                encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
-                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
-            else:
-                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
-        else:
-            if hasattr(self, 'gst'):
-                # B x gst_dim
-                encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
+            if not self.embeddings_per_sample:
+                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
+            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)

        decoder_outputs, alignments, stop_tokens = self.decoder.inference(
            encoder_outputs)
@@ -154,25 +168,21 @@ class Tacotron2(TacotronAbstract):
            decoder_outputs, postnet_outputs, alignments)
        return decoder_outputs, postnet_outputs, alignments, stop_tokens

-    def inference_truncated(self, text, speaker_ids=None, style_mel=None):
+    def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
        """
        Preserve model states for continuous inference
        """
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder.inference_truncated(embedded_inputs)

+        if self.gst:
+            # B x gst_dim
+            encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
+
        if self.num_speakers > 1:
-            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
-            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
-            if hasattr(self, 'gst'):
-                # B x gst_dim
-                encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
-            else:
-                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
-        else:
-            if hasattr(self, 'gst'):
-                # B x gst_dim
-                encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
+            if not self.embeddings_per_sample:
+                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
+            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)

        mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
            encoder_outputs)
@@ -44,7 +44,7 @@ def sequence_mask(sequence_length, max_len=None):
    return seq_range_expand < seq_length_expand


-def setup_model(num_chars, num_speakers, c):
+def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('mozilla_voice_tts.tts.models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
@@ -72,7 +72,8 @@ def setup_model(num_chars, num_speakers, c):
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder,
                        double_decoder_consistency=c.double_decoder_consistency,
-                        ddc_r=c.ddc_r)
+                        ddc_r=c.ddc_r,
+                        speaker_embedding_dim=speaker_embedding_dim)
    elif c.model.lower() == "tacotron2":
        model = MyModel(num_chars=num_chars,
                        num_speakers=num_speakers,

@@ -96,7 +97,8 @@ def setup_model(num_chars, num_speakers, c):
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder,
                        double_decoder_consistency=c.double_decoder_consistency,
-                        ddc_r=c.ddc_r)
+                        ddc_r=c.ddc_r,
+                        speaker_embedding_dim=speaker_embedding_dim)
    return model
@@ -175,7 +177,7 @@ def check_config(c):
    check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
    check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
-    check_argument('spec_gain', c['audio'], restricted=True, val_type=float, min_val=1, max_val=100)
+    check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
    check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
    check_argument('trim_db', c['audio'], restricted=True, val_type=int)

@@ -246,10 +248,10 @@ def check_config(c):
    # paths
    check_argument('output_path', c, restricted=True, val_type=str)

-    # multi-speaker
+    # multi-speaker and gst
    check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
-    # GST
+    check_argument('use_external_speaker_embedding_file', c, restricted=True, val_type=bool)
+    check_argument('external_speaker_embedding_file', c, restricted=True, val_type=str)
    check_argument('use_gst', c, restricted=True, val_type=bool)
    check_argument('gst_style_input', c, restricted=True, val_type=str)
    check_argument('gst', c, restricted=True, val_type=dict)
@@ -10,12 +10,15 @@ def make_speakers_json_path(out_path):
 def load_speaker_mapping(out_path):
     """Loads speaker mapping if already present."""
     try:
-        with open(make_speakers_json_path(out_path)) as f:
+        if os.path.splitext(out_path)[1] == '.json':
+            json_file = out_path
+        else:
+            json_file = make_speakers_json_path(out_path)
+        with open(json_file) as f:
             return json.load(f)
     except FileNotFoundError:
         return {}


 def save_speaker_mapping(out_path, speaker_mapping):
     """Saves speaker mapping if not yet present."""
     speakers_json_path = make_speakers_json_path(out_path)
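A standalone sketch of the new dispatch in load_speaker_mapping: a path that already ends in .json is opened directly, anything else is still treated as a training output folder that holds speakers.json (assuming make_speakers_json_path simply joins that file name, as its name suggests):

    import os

    def resolve_speakers_file(out_path):
        if os.path.splitext(out_path)[1] == '.json':
            return out_path                              # explicit embedding file
        return os.path.join(out_path, 'speakers.json')   # training output folder

    print(resolve_speakers_file('/tmp/tts-run'))                 # /tmp/tts-run/speakers.json
    print(resolve_speakers_file('../../speakers-vctk-en.json'))  # used as-is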