mirror of https://github.com/coqui-ai/TTS.git
control synthesis lenght as an additional stop condition
This commit is contained in:
parent
72cbe545b9
commit
1b68d3cb4e
|
@ -388,7 +388,7 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
self.attention_layer.init_win_idx()
|
self.attention_layer.init_win_idx()
|
||||||
outputs, stop_tokens, alignments, t = [], [], [], 0
|
outputs, stop_tokens, alignments, t = [], [], [], 0
|
||||||
stop_flags = [False, False]
|
stop_flags = [False, False, False]
|
||||||
while True:
|
while True:
|
||||||
memory = self.prenet(memory)
|
memory = self.prenet(memory)
|
||||||
mel_output, stop_token, alignment = self.decode(memory)
|
mel_output, stop_token, alignment = self.decode(memory)
|
||||||
|
@ -398,7 +398,8 @@ class Decoder(nn.Module):
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
||||||
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
|
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
|
||||||
|
stop_flags[2] = t > inputs.shape[1]
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
break
|
break
|
||||||
elif len(outputs) == self.max_decoder_steps:
|
elif len(outputs) == self.max_decoder_steps:
|
||||||
|
|
Loading…
Reference in New Issue