Add a length constraint to the test-time stop signal, so decoding does not stop at a spurious mid-utterance stop token

This commit is contained in:
Eren Golge 2018-05-16 03:59:47 -07:00
parent 40f1a3d3a5
commit 70beccf328
2 changed files with 3 additions and 12 deletions


@@ -311,7 +311,7 @@ class Decoder(nn.Module):
                 if t >= T_decoder:
                     break
             else:
-                if t > 1 and stop_token > 0.8:
+                if t > inputs.shape[1]/2 and stop_token > 0.8:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \

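For context, a minimal sketch of a greedy inference loop using the stopping rule this hunk introduces. The helper `decoder_step` and the `stop_threshold` / `max_decoder_steps` defaults are illustrative assumptions, not the repository's exact API; only the `t > num_input_chars / 2 and stop_token > 0.8` condition comes from the diff.

import torch

def greedy_decode(decoder_step, memory, num_input_chars,
                  stop_threshold=0.8, max_decoder_steps=500):
    # decoder_step(prev_frame, memory) -> (frame, stop_token) is a hypothetical
    # single-step decoder call standing in for the model's decode loop body.
    outputs = []
    frame = None
    t = 0
    while True:
        frame, stop_token = decoder_step(frame, memory)
        outputs.append(frame)
        t += 1
        # Trust the stop token only after at least half as many decoder steps
        # as there are input characters, so a stop signal that fires mid-way
        # through the utterance cannot truncate the output.
        if t > num_input_chars / 2 and stop_token > stop_threshold:
            break
        elif t > max_decoder_steps:
            print(" !! Decoder stopped with 'max_decoder_steps'.")
            break
    return torch.stack(outputs)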

@@ -11,18 +11,9 @@ hop_length = 250
 def create_speech(m, s, CONFIG, use_cuda, ap):
     text_cleaner = [CONFIG.text_cleaner]
     seq = np.array(text_to_sequence(s, text_cleaner))
-    # mel = np.zeros([seq.shape[0], CONFIG.num_mels, 1], dtype=np.float32)
+    chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
-        chars_var = torch.autograd.Variable(
-            torch.from_numpy(seq), volatile=True).unsqueeze(0).cuda()
-        # mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.cuda.FloatTensor), volatile=True).cuda()
-    else:
-        chars_var = torch.autograd.Variable(
-            torch.from_numpy(seq), volatile=True).unsqueeze(0)
-        # mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
+        chars_var = chars_var.cuda()
     mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
     linear_out = linear_out[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
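The old branches wrapped the input in `torch.autograd.Variable(..., volatile=True)`, which PyTorch 0.4 deprecates; the new code builds the tensor directly and moves it to GPU only when needed. A minimal sketch of the simplified input preparation, assuming `seq` is the character-id array from `text_to_sequence()`; the `torch.no_grad()` note is my assumption and is not part of this commit.

import torch

def prepare_chars(seq, use_cuda):
    # seq: 1-D numpy int array of character ids from text_to_sequence()
    chars_var = torch.from_numpy(seq).unsqueeze(0)  # add the batch dimension
    if use_cuda:
        chars_var = chars_var.cuda()
    return chars_var

# At inference time gradients are not needed; volatile=True used to express
# this, and a no_grad() context is its modern replacement (not in this commit):
# with torch.no_grad():
#     mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)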