a new hacky way to stop generation and test notebook update

This commit is contained in:
Eren Golge 2018-04-13 05:09:14 -07:00
parent 02c0e359b3
commit 4b34fe1e3e
2 changed files with 593 additions and 22 deletions

View File

@ -216,7 +216,7 @@ class Decoder(nn.Module):
eps (float): threshold for detecting the end of a sentence.
"""
def __init__(self, in_features, memory_dim, r, eps=0.05, mode='train'):
def __init__(self, in_features, memory_dim, r, eps=0, mode='train'):
super(Decoder, self).__init__()
self.mode = mode
self.max_decoder_steps = 200
@ -310,7 +310,7 @@ class Decoder(nn.Module):
if t >= T_decoder:
break
else:
if t > 1 and is_end_of_frames(output, self.eps):
if t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps):
break
elif t > self.max_decoder_steps:
print(" !! Decoder stopped with 'max_decoder_steps'. \
@ -323,5 +323,6 @@ class Decoder(nn.Module):
return outputs, alignments
def is_end_of_frames(output, eps=0.2): # 0.2
return (output.data <= eps).all()
def is_end_of_frames(output, alignment, eps=0.05): # 0.2
return ((output.data <= eps).prod(0) > 0).any() \
and alignment.data[:, int(alignment.shape[1]/2):].sum() > 0.7

File diff suppressed because one or more lines are too long