mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
commit e55e28bb37
@@ -138,7 +138,7 @@ class GravesAttention(nn.Module):

     def init_states(self, inputs):
         if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
-            self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
+            self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
         self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
         self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
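The one-character change above is a dtype fix: with two integer arguments, torch.arange returns an int64 tensor, and on PyTorch releases of that era a Python float added to an integer tensor could be cast to the tensor's dtype, so the + 0.5 half-step offset was lost. Passing 2.0 as the end value makes the grid floating point from the start. A minimal sketch (T is a stand-in for the number of encoder timesteps, not a name from the repo):

import torch

T = 10  # hypothetical encoder length, for illustration only
int_grid = torch.arange(0, T + 2)      # dtype is torch.int64
float_grid = torch.arange(0, T + 2.0)  # dtype is torch.float32

# Graves attention needs half-offset positions (0.5, 1.5, ...), which
# only survive once the grid is floating point.
print((float_grid + 0.5)[:4])  # tensor([0.5000, 1.5000, 2.5000, 3.5000])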
@@ -164,16 +164,20 @@ class Synthesizer(object):
         sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences
         return sentences

-    def tts(self, text):
+    def tts(self, text, speaker_id=None):
         wavs = []
         sens = self.split_into_sentences(text)
         print(sens)
+        speaker_id = id_to_torch(speaker_id)
+        if speaker_id is not None and self.use_cuda:
+            speaker_id = speaker_id.cuda()
+
         for sen in sens:
             # preprocess the given text
             inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda)
             # synthesize voice
             decoder_output, postnet_output, alignments, _ = run_model(
-                self.tts_model, inputs, self.tts_config, False, None, None)
+                self.tts_model, inputs, self.tts_config, False, speaker_id, None)
             # convert outputs to numpy
             postnet_output, decoder_output, _ = parse_outputs(
                 postnet_output, decoder_output, alignments)
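This hunk threads an optional speaker_id through Synthesizer.tts to run_model so multi-speaker models can select a voice. id_to_torch is the repo's helper; a rough sketch of what it needs to do (the body here is an assumption, not the repo's exact code):

import torch

def id_to_torch(speaker_id):
    # Hypothetical sketch: wrap a plain integer id in a LongTensor so an
    # embedding layer can consume it; pass None through untouched for
    # single-speaker models.
    if speaker_id is not None:
        speaker_id = torch.tensor([speaker_id], dtype=torch.long)
    return speaker_id

Because the parameter defaults to None, existing single-speaker calls keep working, while callers can now write e.g. synthesizer.tts("Hello world.", speaker_id=3).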
@@ -34,6 +34,7 @@
     "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
     "print_step": 1, // Number of steps to log traning on console.
     "output_path": "/media/erogol/data_ssd/Models/libri_tts/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
+    "num_loader_workers": 0, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "model": {
         "input_dim": 40,
         "proj_dim": 128,
@@ -44,7 +44,7 @@ def setup_loader(ap, is_val=False, verbose=False):
     loader = DataLoader(dataset,
                         batch_size=c.num_speakers_in_batch,
                         shuffle=False,
-                        num_workers=0,
+                        num_workers=c.num_loader_workers,
                         collate_fn=dataset.collate_fn)
     return loader
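The two hunks above belong together: the speaker-encoder config gains a num_loader_workers key, and setup_loader stops hard-coding num_workers=0 so the JSON value drives DataLoader parallelism (0 means batches are prepared in the main process; 4-8 background workers usually hide feature-extraction latency). A self-contained sketch with a dummy dataset, using a plain dict in place of the parsed config:

import torch
from torch.utils.data import DataLoader, TensorDataset

c = {"num_speakers_in_batch": 2, "num_loader_workers": 0}  # stand-in config
dataset = TensorDataset(torch.randn(8, 40))  # dummy 40-dim feature rows

loader = DataLoader(dataset,
                    batch_size=c["num_speakers_in_batch"],
                    shuffle=False,
                    num_workers=c["num_loader_workers"])

for (batch,) in loader:
    print(batch.shape)  # torch.Size([2, 40])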
train.py (2 changed lines)
@@ -356,7 +356,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     mel_lengths, decoder_backward_output,
                     alignments, alignment_lengths, text_lengths)
                 if c.bidirectional_decoder:
-                    keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
+                    keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
                                            'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
                 if c.ga_alpha > 0:
                     keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})
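The evaluate() fix is a key rename: judging by the sibling 'decoder_c_loss' entry, the criterion returns its backward-decoder loss under 'decoder_b_loss', so the old 'decoder_backward_loss' lookup raised a KeyError whenever the bidirectional decoder was enabled during evaluation. An illustration with made-up values:

import torch

loss_dict = {'decoder_b_loss': torch.tensor(0.42),  # illustrative values
             'decoder_c_loss': torch.tensor(0.17)}

try:
    loss_dict['decoder_backward_loss'].item()  # the old lookup
except KeyError as err:
    print('old key fails:', err)

print('fixed key:', loss_dict['decoder_b_loss'].item())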