mirror of https://github.com/coqui-ai/TTS.git
Perform stop token prediction to stop the model.
This commit is contained in:
parent
1fa6068e9d
commit
f8d5bbd5d2
|
@ -315,7 +315,7 @@ class Decoder(nn.Module):
|
||||||
if t >= T_decoder:
|
if t >= T_decoder:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps):
|
if t > 1 and stop_token.sum().item() > 0.7:
|
||||||
break
|
break
|
||||||
elif t > self.max_decoder_steps:
|
elif t > self.max_decoder_steps:
|
||||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||||
|
@ -326,7 +326,7 @@ class Decoder(nn.Module):
|
||||||
alignments = torch.stack(alignments).transpose(0, 1)
|
alignments = torch.stack(alignments).transpose(0, 1)
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
return outputs, alignments, stop_tokens
|
return outputs, alignments, stop_tokens.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def is_end_of_frames(output, alignment, eps=0.01): # 0.2
|
def is_end_of_frames(output, alignment, eps=0.01): # 0.2
|
||||||
|
|
Loading…
Reference in New Issue