mirror of https://github.com/coqui-ai/TTS.git
fix stop condition
This commit is contained in:
parent
7e5c20500b
commit
e8d29613f1
|
@ -255,7 +255,6 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
outputs, stop_tokens, alignments, t = [], [], [], 0
|
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||||
stop_flags = [True, False, False]
|
stop_flags = [True, False, False]
|
||||||
stop_count = 0
|
|
||||||
while True:
|
while True:
|
||||||
memory = self.prenet(memory)
|
memory = self.prenet(memory)
|
||||||
mel_output, stop_token, alignment = self.decode(memory)
|
mel_output, stop_token, alignment = self.decode(memory)
|
||||||
|
@ -269,9 +268,7 @@ class Decoder(nn.Module):
|
||||||
and t > inputs.shape[1])
|
and t > inputs.shape[1])
|
||||||
stop_flags[2] = t > inputs.shape[1] * 2
|
stop_flags[2] = t > inputs.shape[1] * 2
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
stop_count += 1
|
break
|
||||||
if stop_count > 20:
|
|
||||||
break
|
|
||||||
elif len(outputs) == self.max_decoder_steps:
|
elif len(outputs) == self.max_decoder_steps:
|
||||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
break
|
break
|
||||||
|
@ -298,7 +295,6 @@ class Decoder(nn.Module):
|
||||||
self.attention.init_states(inputs)
|
self.attention.init_states(inputs)
|
||||||
outputs, stop_tokens, alignments, t = [], [], [], 0
|
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||||
stop_flags = [True, False, False]
|
stop_flags = [True, False, False]
|
||||||
stop_count = 0
|
|
||||||
while True:
|
while True:
|
||||||
memory = self.prenet(self.memory_truncated)
|
memory = self.prenet(self.memory_truncated)
|
||||||
mel_output, stop_token, alignment = self.decode(memory)
|
mel_output, stop_token, alignment = self.decode(memory)
|
||||||
|
@ -312,9 +308,7 @@ class Decoder(nn.Module):
|
||||||
and t > inputs.shape[1])
|
and t > inputs.shape[1])
|
||||||
stop_flags[2] = t > inputs.shape[1] * 2
|
stop_flags[2] = t > inputs.shape[1] * 2
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
stop_count += 1
|
break
|
||||||
if stop_count > 20:
|
|
||||||
break
|
|
||||||
elif len(outputs) == self.max_decoder_steps:
|
elif len(outputs) == self.max_decoder_steps:
|
||||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in New Issue