Merge branch 'dev' of https://github.com/mozilla/TTS into dev

This commit is contained in:
Eren Gölge 2020-05-18 13:12:31 +02:00
commit e55e28bb37
5 changed files with 10 additions and 5 deletions

View File

@ -138,7 +138,7 @@ class GravesAttention(nn.Module):
def init_states(self, inputs):
if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)

View File

@ -164,16 +164,20 @@ class Synthesizer(object):
sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences
return sentences
def tts(self, text):
def tts(self, text, speaker_id=None):
wavs = []
sens = self.split_into_sentences(text)
print(sens)
speaker_id = id_to_torch(speaker_id)
if speaker_id is not None and self.use_cuda:
speaker_id = speaker_id.cuda()
for sen in sens:
# preprocess the given text
inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda)
# synthesize voice
decoder_output, postnet_output, alignments, _ = run_model(
self.tts_model, inputs, self.tts_config, False, None, None)
self.tts_model, inputs, self.tts_config, False, speaker_id, None)
# convert outputs to numpy
postnet_output, decoder_output, _ = parse_outputs(
postnet_output, decoder_output, alignments)

View File

@ -34,6 +34,7 @@
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 1, // Number of steps to log traning on console.
"output_path": "/media/erogol/data_ssd/Models/libri_tts/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
"num_loader_workers": 0, // number of training data loader processes. Don't set it too big. 4-8 are good values.
"model": {
"input_dim": 40,
"proj_dim": 128,

View File

@ -44,7 +44,7 @@ def setup_loader(ap, is_val=False, verbose=False):
loader = DataLoader(dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=0,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn)
return loader

View File

@ -356,7 +356,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
mel_lengths, decoder_backward_output,
alignments, alignment_lengths, text_lengths)
if c.bidirectional_decoder:
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
if c.ga_alpha > 0:
keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})