mirror of https://github.com/coqui-ai/TTS.git
predict stop token from rnn out + mel
This commit is contained in:
parent
d629dafb20
commit
02d72ccbfe
|
@ -232,11 +232,11 @@ class Decoder(nn.Module):
|
||||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||||
# RNN_state -> |Linear| -> mel_spec
|
# RNN_state -> |Linear| -> mel_spec
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
self.stopnet = nn.Sequential(nn.Linear(memory_dim * self.r, memory_dim),
|
self.stopnet = nn.Sequential(nn.Linear(256 + self.r * memory_dim, memory_dim),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(memory_dim, memory_dim * self.r),
|
nn.Linear(memory_dim, 256 + self.r * memory_dim),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(memory_dim * self.r, 1),
|
nn.Linear(256 + self.r * memory_dim, 1),
|
||||||
nn.Sigmoid())
|
nn.Sigmoid())
|
||||||
|
|
||||||
def forward(self, inputs, memory=None):
|
def forward(self, inputs, memory=None):
|
||||||
|
@ -301,11 +301,12 @@ class Decoder(nn.Module):
|
||||||
decoder_input, decoder_rnn_hiddens[idx])
|
decoder_input, decoder_rnn_hiddens[idx])
|
||||||
# Residual connectinon
|
# Residual connectinon
|
||||||
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
|
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
|
||||||
output = decoder_input
|
decoder_output = decoder_input
|
||||||
# predict mel vectors from decoder vectors
|
# predict mel vectors from decoder vectors
|
||||||
output = self.proj_to_mel(output)
|
output = self.proj_to_mel(decoder_output)
|
||||||
|
stop_input = torch.cat((output, decoder_output), -1)
|
||||||
# predict stop token
|
# predict stop token
|
||||||
stop_token = self.stopnet(output)
|
stop_token = self.stopnet(stop_input)
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
|
@ -326,8 +327,3 @@ class Decoder(nn.Module):
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
return outputs, alignments, stop_tokens
|
return outputs, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
def is_end_of_frames(output, alignment, eps=0.05): # 0.2
|
|
||||||
return ((output.data <= eps).prod(0) > 0).any() \
|
|
||||||
and alignment.data[:, int(alignment.shape[1]/2):].sum() > 0.7
|
|
||||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue