mirror of https://github.com/coqui-ai/TTS.git
Perform stop token prediction to stop the model.
This commit is contained in:
parent
1fa6068e9d
commit
f8d5bbd5d2
|
@ -315,7 +315,7 @@ class Decoder(nn.Module):
|
|||
if t >= T_decoder:
|
||||
break
|
||||
else:
|
||||
if t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps):
|
||||
if t > 1 and stop_token.sum().item() > 0.7:
|
||||
break
|
||||
elif t > self.max_decoder_steps:
|
||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||
|
@ -326,7 +326,7 @@ class Decoder(nn.Module):
|
|||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
return outputs, alignments, stop_tokens
|
||||
return outputs, alignments, stop_tokens.squeeze()
|
||||
|
||||
|
||||
def is_end_of_frames(output, alignment, eps=0.01): # 0.2
|
||||
|
|
Loading…
Reference in New Issue