mirror of https://github.com/coqui-ai/TTS.git
Add a length constraint for test time stop signal, to avoid stopage at a mid point stop sign
This commit is contained in:
parent
40f1a3d3a5
commit
70beccf328
|
@ -311,7 +311,7 @@ class Decoder(nn.Module):
|
|||
if t >= T_decoder:
|
||||
break
|
||||
else:
|
||||
if t > 1 and stop_token > 0.8:
|
||||
if t > inputs.shape[1]/2 and stop_token > 0.8:
|
||||
break
|
||||
elif t > self.max_decoder_steps:
|
||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||
|
|
|
@ -11,18 +11,9 @@ hop_length = 250
|
|||
def create_speech(m, s, CONFIG, use_cuda, ap):
|
||||
text_cleaner = [CONFIG.text_cleaner]
|
||||
seq = np.array(text_to_sequence(s, text_cleaner))
|
||||
|
||||
# mel = np.zeros([seq.shape[0], CONFIG.num_mels, 1], dtype=np.float32)
|
||||
|
||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||
if use_cuda:
|
||||
chars_var = torch.autograd.Variable(
|
||||
torch.from_numpy(seq), volatile=True).unsqueeze(0).cuda()
|
||||
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.cuda.FloatTensor), volatile=True).cuda()
|
||||
else:
|
||||
chars_var = torch.autograd.Variable(
|
||||
torch.from_numpy(seq), volatile=True).unsqueeze(0)
|
||||
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
|
||||
|
||||
chars_var = chars_var.cuda()
|
||||
mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
|
||||
linear_out = linear_out[0].data.cpu().numpy()
|
||||
alignment = alignments[0].cpu().data.numpy()
|
||||
|
|
Loading…
Reference in New Issue