mirror of https://github.com/coqui-ai/TTS.git
Update the Tacotron and Tacotron2 models to return `model_outputs` (renamed from `postnet_outputs`)
This commit is contained in:
parent
bb4deee64c
commit
4e910993f1
|
@ -255,7 +255,7 @@ class Tacotron(TacotronAbstract):
|
||||||
outputs['alignments_backward'] = alignments_backward
|
outputs['alignments_backward'] = alignments_backward
|
||||||
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||||
outputs.update({
|
outputs.update({
|
||||||
'postnet_outputs': postnet_outputs,
|
'model_outputs': postnet_outputs,
|
||||||
'decoder_outputs': decoder_outputs,
|
'decoder_outputs': decoder_outputs,
|
||||||
'alignments': alignments,
|
'alignments': alignments,
|
||||||
'stop_tokens': stop_tokens
|
'stop_tokens': stop_tokens
|
||||||
|
@ -287,7 +287,7 @@ class Tacotron(TacotronAbstract):
|
||||||
postnet_outputs = self.last_linear(postnet_outputs)
|
postnet_outputs = self.last_linear(postnet_outputs)
|
||||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||||
outputs = {
|
outputs = {
|
||||||
'postnet_outputs': postnet_outputs,
|
'model_outputs': postnet_outputs,
|
||||||
'decoder_outputs': decoder_outputs,
|
'decoder_outputs': decoder_outputs,
|
||||||
'alignments': alignments,
|
'alignments': alignments,
|
||||||
'stop_tokens': stop_tokens
|
'stop_tokens': stop_tokens
|
||||||
|
@ -335,7 +335,7 @@ class Tacotron(TacotronAbstract):
|
||||||
|
|
||||||
# compute loss
|
# compute loss
|
||||||
loss_dict = criterion(
|
loss_dict = criterion(
|
||||||
outputs['postnet_outputs'],
|
outputs['model_outputs'],
|
||||||
outputs['decoder_outputs'],
|
outputs['decoder_outputs'],
|
||||||
mel_input,
|
mel_input,
|
||||||
linear_input,
|
linear_input,
|
||||||
|
@ -355,7 +355,7 @@ class Tacotron(TacotronAbstract):
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_log(self, ap, batch, outputs):
|
def train_log(self, ap, batch, outputs):
|
||||||
postnet_outputs = outputs['postnet_outputs']
|
postnet_outputs = outputs['model_outputs']
|
||||||
alignments = outputs['alignments']
|
alignments = outputs['alignments']
|
||||||
alignments_backward = outputs['alignments_backward']
|
alignments_backward = outputs['alignments_backward']
|
||||||
mel_input = batch['mel_input']
|
mel_input = batch['mel_input']
|
||||||
|
|
|
@ -233,7 +233,7 @@ class Tacotron2(TacotronAbstract):
|
||||||
outputs['alignments_backward'] = alignments_backward
|
outputs['alignments_backward'] = alignments_backward
|
||||||
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||||
outputs.update({
|
outputs.update({
|
||||||
'postnet_outputs': postnet_outputs,
|
'model_outputs': postnet_outputs,
|
||||||
'decoder_outputs': decoder_outputs,
|
'decoder_outputs': decoder_outputs,
|
||||||
'alignments': alignments,
|
'alignments': alignments,
|
||||||
'stop_tokens': stop_tokens
|
'stop_tokens': stop_tokens
|
||||||
|
@ -254,7 +254,7 @@ class Tacotron2(TacotronAbstract):
|
||||||
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
|
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
|
||||||
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
|
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
|
||||||
else:
|
else:
|
||||||
x_vector = cond_input
|
x_vector = cond_input['x_vectors']
|
||||||
|
|
||||||
encoder_outputs = self._concat_speaker_embedding(
|
encoder_outputs = self._concat_speaker_embedding(
|
||||||
encoder_outputs, x_vector)
|
encoder_outputs, x_vector)
|
||||||
|
@ -266,7 +266,7 @@ class Tacotron2(TacotronAbstract):
|
||||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||||
decoder_outputs, postnet_outputs, alignments)
|
decoder_outputs, postnet_outputs, alignments)
|
||||||
outputs = {
|
outputs = {
|
||||||
'postnet_outputs': postnet_outputs,
|
'model_outputs': postnet_outputs,
|
||||||
'decoder_outputs': decoder_outputs,
|
'decoder_outputs': decoder_outputs,
|
||||||
'alignments': alignments,
|
'alignments': alignments,
|
||||||
'stop_tokens': stop_tokens
|
'stop_tokens': stop_tokens
|
||||||
|
|
Loading…
Reference in New Issue