mirror of https://github.com/coqui-ai/TTS.git
Stop token layer on decoder
This commit is contained in:
parent
1d8ad0968c
commit
3ea1a5358d
|
@ -232,6 +232,7 @@ class Decoder(nn.Module):
|
||||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||||
# RNN_state -> |Linear| -> mel_spec
|
# RNN_state -> |Linear| -> mel_spec
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
|
self.stopnet = nn.Sequential(nn.Dropout(0.2), nn.Linear(memory_dim * self.r, 1), nn.Sigmoid())
|
||||||
|
|
||||||
def forward(self, inputs, memory=None):
|
def forward(self, inputs, memory=None):
|
||||||
"""
|
"""
|
||||||
|
@ -272,6 +273,7 @@ class Decoder(nn.Module):
|
||||||
memory = memory.transpose(0, 1)
|
memory = memory.transpose(0, 1)
|
||||||
outputs = []
|
outputs = []
|
||||||
alignments = []
|
alignments = []
|
||||||
|
stop_tokens = []
|
||||||
t = 0
|
t = 0
|
||||||
memory_input = initial_memory
|
memory_input = initial_memory
|
||||||
while True:
|
while True:
|
||||||
|
@ -297,8 +299,11 @@ class Decoder(nn.Module):
|
||||||
output = decoder_input
|
output = decoder_input
|
||||||
# predict mel vectors from decoder vectors
|
# predict mel vectors from decoder vectors
|
||||||
output = self.proj_to_mel(output)
|
output = self.proj_to_mel(output)
|
||||||
|
# predict stop token
|
||||||
|
stop_token = self.stopnet(output)
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
stop_tokens += stop_token
|
||||||
t += 1
|
t += 1
|
||||||
if (not greedy and self.training) or (greedy and memory is not None):
|
if (not greedy and self.training) or (greedy and memory is not None):
|
||||||
if t >= T_decoder:
|
if t >= T_decoder:
|
||||||
|
@ -314,7 +319,8 @@ class Decoder(nn.Module):
|
||||||
# Back to batch first
|
# Back to batch first
|
||||||
alignments = torch.stack(alignments).transpose(0, 1)
|
alignments = torch.stack(alignments).transpose(0, 1)
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
return outputs, alignments
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
|
return outputs, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
def is_end_of_frames(output, alignment, eps=0.05): # 0.2
|
def is_end_of_frames(output, alignment, eps=0.05): # 0.2
|
||||||
|
|
Loading…
Reference in New Issue