mirror of https://github.com/coqui-ai/TTS.git
Add a length constraint for test time stop signal, to avoid stopage at a mid point stop sign
This commit is contained in:
parent
7a64c6e383
commit
e39b1b3275
|
@ -311,7 +311,7 @@ class Decoder(nn.Module):
|
||||||
if t >= T_decoder:
|
if t >= T_decoder:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if t > 1 and stop_token > 0.8:
|
if t > inputs.shape[1]/2 and stop_token > 0.8:
|
||||||
break
|
break
|
||||||
elif t > self.max_decoder_steps:
|
elif t > self.max_decoder_steps:
|
||||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||||
|
|
|
@ -11,18 +11,9 @@ hop_length = 250
|
||||||
def create_speech(m, s, CONFIG, use_cuda, ap):
|
def create_speech(m, s, CONFIG, use_cuda, ap):
|
||||||
text_cleaner = [CONFIG.text_cleaner]
|
text_cleaner = [CONFIG.text_cleaner]
|
||||||
seq = np.array(text_to_sequence(s, text_cleaner))
|
seq = np.array(text_to_sequence(s, text_cleaner))
|
||||||
|
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||||
# mel = np.zeros([seq.shape[0], CONFIG.num_mels, 1], dtype=np.float32)
|
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
chars_var = torch.autograd.Variable(
|
chars_var = chars_var.cuda()
|
||||||
torch.from_numpy(seq), volatile=True).unsqueeze(0).cuda()
|
|
||||||
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.cuda.FloatTensor), volatile=True).cuda()
|
|
||||||
else:
|
|
||||||
chars_var = torch.autograd.Variable(
|
|
||||||
torch.from_numpy(seq), volatile=True).unsqueeze(0)
|
|
||||||
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
|
|
||||||
|
|
||||||
mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
|
mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
|
||||||
linear_out = linear_out[0].data.cpu().numpy()
|
linear_out = linear_out[0].data.cpu().numpy()
|
||||||
alignment = alignments[0].cpu().data.numpy()
|
alignment = alignments[0].cpu().data.numpy()
|
||||||
|
|
Loading…
Reference in New Issue