mirror of https://github.com/coqui-ai/TTS.git
renaming
This commit is contained in:
parent
5cbe0f83f6
commit
527567d7ce
|
@ -299,15 +299,15 @@ class Decoder(nn.Module):
|
||||||
memories = memories.transpose(0, 1)
|
memories = memories.transpose(0, 1)
|
||||||
return memories
|
return memories
|
||||||
|
|
||||||
def _parse_outputs(self, outputs, gate_outputs, alignments):
|
def _parse_outputs(self, outputs, stop_tokens, alignments):
|
||||||
alignments = torch.stack(alignments).transpose(0, 1)
|
alignments = torch.stack(alignments).transpose(0, 1)
|
||||||
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
gate_outputs = gate_outputs.contiguous()
|
stop_tokens = stop_tokens.contiguous()
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
outputs = outputs.view(
|
outputs = outputs.view(
|
||||||
outputs.size(0), -1, self.mel_channels)
|
outputs.size(0), -1, self.mel_channels)
|
||||||
outputs = outputs.transpose(1, 2)
|
outputs = outputs.transpose(1, 2)
|
||||||
return outputs, gate_outputs, alignments
|
return outputs, stop_tokens, alignments
|
||||||
|
|
||||||
def decode(self, memory):
|
def decode(self, memory):
|
||||||
cell_input = torch.cat((memory, self.context), -1)
|
cell_input = torch.cat((memory, self.context), -1)
|
||||||
|
@ -354,36 +354,36 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
self._init_states(inputs, mask=mask)
|
self._init_states(inputs, mask=mask)
|
||||||
|
|
||||||
outputs, gate_outputs, alignments = [], [], []
|
outputs, stop_tokens, alignments = [], [], []
|
||||||
while len(outputs) < memories.size(0) - 1:
|
while len(outputs) < memories.size(0) - 1:
|
||||||
memory = memories[len(outputs)]
|
memory = memories[len(outputs)]
|
||||||
mel_output, gate_output, attention_weights = self.decode(
|
mel_output, stop_token, attention_weights = self.decode(
|
||||||
memory)
|
memory)
|
||||||
outputs += [mel_output.squeeze(1)]
|
outputs += [mel_output.squeeze(1)]
|
||||||
gate_outputs += [gate_output.squeeze(1)]
|
stop_tokens += [stop_token.squeeze(1)]
|
||||||
alignments += [attention_weights]
|
alignments += [attention_weights]
|
||||||
|
|
||||||
outputs, gate_outputs, alignments = self._parse_outputs(
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
||||||
outputs, gate_outputs, alignments)
|
outputs, stop_tokens, alignments)
|
||||||
|
|
||||||
return outputs, gate_outputs, alignments
|
return outputs, stop_tokens, alignments
|
||||||
|
|
||||||
def inference(self, inputs):
|
def inference(self, inputs):
|
||||||
memory = self.get_go_frame(inputs)
|
memory = self.get_go_frame(inputs)
|
||||||
self._init_states(inputs, mask=None)
|
self._init_states(inputs, mask=None)
|
||||||
|
|
||||||
self.attention_layer.init_win_idx()
|
self.attention_layer.init_win_idx()
|
||||||
outputs, gate_outputs, alignments, t = [], [], [], 0
|
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||||
stop_flags = [False, False]
|
stop_flags = [False, False]
|
||||||
while True:
|
while True:
|
||||||
memory = self.prenet(memory)
|
memory = self.prenet(memory)
|
||||||
mel_output, gate_output, alignment = self.decode(memory)
|
mel_output, stop_token, alignment = self.decode(memory)
|
||||||
gate_output = torch.sigmoid(gate_output.data)
|
stop_token = torch.sigmoid(stop_token.data)
|
||||||
outputs += [mel_output.squeeze(1)]
|
outputs += [mel_output.squeeze(1)]
|
||||||
gate_outputs += [gate_output]
|
stop_tokens += [stop_token]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
stop_flags[0] = stop_flags[0] or gate_output > 0.5
|
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
||||||
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
|
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
break
|
break
|
||||||
|
@ -394,10 +394,10 @@ class Decoder(nn.Module):
|
||||||
memory = mel_output
|
memory = mel_output
|
||||||
t += 1
|
t += 1
|
||||||
|
|
||||||
outputs, gate_outputs, alignments = self._parse_outputs(
|
outputs, stop_tokens, alignments = self._parse_outputs(
|
||||||
outputs, gate_outputs, alignments)
|
outputs, stop_tokens, alignments)
|
||||||
|
|
||||||
return outputs, gate_outputs, alignments
|
return outputs, stop_tokens, alignments
|
||||||
|
|
||||||
def inference_step(self, inputs, t, memory=None):
|
def inference_step(self, inputs, t, memory=None):
|
||||||
"""
|
"""
|
||||||
|
@ -408,7 +408,7 @@ class Decoder(nn.Module):
|
||||||
self._init_states(inputs, mask=None)
|
self._init_states(inputs, mask=None)
|
||||||
|
|
||||||
memory = self.prenet(memory)
|
memory = self.prenet(memory)
|
||||||
mel_output, gate_output, alignment = self.decode(memory)
|
mel_output, stop_token, alignment = self.decode(memory)
|
||||||
gate_output = torch.sigmoid(gate_output.data)
|
stop_token = torch.sigmoid(stop_token.data)
|
||||||
memory = mel_output
|
memory = mel_output
|
||||||
return mel_output, gate_output, alignment
|
return mel_output, stop_token, alignment
|
||||||
|
|
Loading…
Reference in New Issue