mirror of https://github.com/coqui-ai/TTS.git
a new hacky way to stop generation and test notebook update
This commit is contained in:
parent
02c0e359b3
commit
4b34fe1e3e
|
@ -216,7 +216,7 @@ class Decoder(nn.Module):
|
|||
eps (float): threshold for detecting the end of a sentence.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, memory_dim, r, eps=0.05, mode='train'):
|
||||
def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
|
||||
super(Decoder, self).__init__()
|
||||
self.mode = mode
|
||||
self.max_decoder_steps = 200
|
||||
|
@ -310,7 +310,7 @@ class Decoder(nn.Module):
|
|||
if t >= T_decoder:
|
||||
break
|
||||
else:
|
||||
if t > 1 and is_end_of_frames(output, self.eps):
|
||||
if t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps):
|
||||
break
|
||||
elif t > self.max_decoder_steps:
|
||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||
|
@ -323,5 +323,6 @@ class Decoder(nn.Module):
|
|||
return outputs, alignments
|
||||
|
||||
|
||||
def is_end_of_frames(output, eps=0.2): # 0.2
|
||||
return (output.data <= eps).all()
|
||||
def is_end_of_frames(output, alignment, eps=0.05): # 0.2
|
||||
return ((output.data <= eps).prod(0) > 0).any() \
|
||||
and alignment.data[:, int(alignment.shape[1]/2):].sum() > 0.7
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue