mirror of https://github.com/coqui-ai/TTS.git
inference truncated NEED TO BE TESTED
This commit is contained in:
parent
5cbe0f83f6
commit
b9b79fcf0f
|
@ -202,6 +202,7 @@ class Encoder(nn.Module):
|
||||||
num_layers=1,
|
num_layers=1,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
bidirectional=True)
|
bidirectional=True)
|
||||||
|
self.rnn_state = None
|
||||||
|
|
||||||
def forward(self, x, input_lengths):
|
def forward(self, x, input_lengths):
|
||||||
x = self.convolutions(x)
|
x = self.convolutions(x)
|
||||||
|
@ -224,6 +225,16 @@ class Encoder(nn.Module):
|
||||||
outputs, _ = self.lstm(x)
|
outputs, _ = self.lstm(x)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def inference_truncated(self, x):
|
||||||
|
"""
|
||||||
|
Preserve encoder state for continuous inference
|
||||||
|
"""
|
||||||
|
x = self.convolutions(x)
|
||||||
|
x = x.transpose(1, 2)
|
||||||
|
self.lstm.flatten_parameters()
|
||||||
|
outputs, self.rnn_state = self.lstm(x, self.rnn_state)
|
||||||
|
return outputs
|
||||||
|
|
||||||
# adapted from https://github.com/NVIDIA/tacotron2/
|
# adapted from https://github.com/NVIDIA/tacotron2/
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
def __init__(self, in_features, inputs_dim, r, attn_win):
|
def __init__(self, in_features, inputs_dim, r, attn_win):
|
||||||
|
@ -264,31 +275,34 @@ class Decoder(nn.Module):
|
||||||
self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
|
self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
|
||||||
self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
|
self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
|
||||||
self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
|
self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
|
||||||
|
self.memory_truncated = None
|
||||||
|
|
||||||
def get_go_frame(self, inputs):
|
def get_go_frame(self, inputs):
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
memory = self.go_frame_init(inputs.data.new_zeros(B).long())
|
memory = self.go_frame_init(inputs.data.new_zeros(B).long())
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def _init_states(self, inputs, mask):
|
def _init_states(self, inputs, mask, keep_states=False):
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
T = inputs.size(1)
|
T = inputs.size(1)
|
||||||
|
|
||||||
self.attention_hidden = self.attention_rnn_init(
|
if not keep_states:
|
||||||
inputs.data.new_zeros(B).long())
|
self.attention_hidden = self.attention_rnn_init(
|
||||||
self.attention_cell = Variable(
|
inputs.data.new_zeros(B).long())
|
||||||
inputs.data.new(B, self.attention_rnn_dim).zero_())
|
self.attention_cell = Variable(
|
||||||
|
inputs.data.new(B, self.attention_rnn_dim).zero_())
|
||||||
|
|
||||||
self.decoder_hidden = self.decoder_rnn_inits(
|
self.decoder_hidden = self.decoder_rnn_inits(
|
||||||
inputs.data.new_zeros(B).long())
|
inputs.data.new_zeros(B).long())
|
||||||
self.decoder_cell = Variable(
|
self.decoder_cell = Variable(
|
||||||
inputs.data.new(B, self.decoder_rnn_dim).zero_())
|
inputs.data.new(B, self.decoder_rnn_dim).zero_())
|
||||||
|
|
||||||
|
self.context = Variable(
|
||||||
|
inputs.data.new(B, self.encoder_embedding_dim).zero_())
|
||||||
|
|
||||||
self.attention_weights = Variable(inputs.data.new(B, T).zero_())
|
self.attention_weights = Variable(inputs.data.new(B, T).zero_())
|
||||||
self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
|
self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
|
||||||
self.context = Variable(
|
|
||||||
inputs.data.new(B, self.encoder_embedding_dim).zero_())
|
|
||||||
|
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
|
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
|
@ -399,6 +413,44 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
return outputs, gate_outputs, alignments
|
return outputs, gate_outputs, alignments
|
||||||
|
|
||||||
|
def inference_truncated(self, inputs):
|
||||||
|
"""
|
||||||
|
Preserve decoder states for continuous inference
|
||||||
|
"""
|
||||||
|
if self.memory_truncated is None:
|
||||||
|
self.memory_truncated = self.get_go_frame(inputs)
|
||||||
|
self._init_states(inputs, mask=None, keep_states=False)
|
||||||
|
else:
|
||||||
|
self._init_states(inputs, mask=None, keep_states=True)
|
||||||
|
|
||||||
|
self.attention_layer.init_win_idx()
|
||||||
|
outputs, gate_outputs, alignments, t = [], [], [], 0
|
||||||
|
stop_flags = [False, False]
|
||||||
|
while True:
|
||||||
|
memory = self.prenet(self.memory_truncated)
|
||||||
|
mel_output, gate_output, alignment = self.decode(memory)
|
||||||
|
gate_output = torch.sigmoid(gate_output.data)
|
||||||
|
outputs += [mel_output.squeeze(1)]
|
||||||
|
gate_outputs += [gate_output]
|
||||||
|
alignments += [alignment]
|
||||||
|
|
||||||
|
stop_flags[0] = stop_flags[0] or gate_output > 0.5
|
||||||
|
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
|
||||||
|
if all(stop_flags):
|
||||||
|
break
|
||||||
|
elif len(outputs) == self.max_decoder_steps:
|
||||||
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
|
break
|
||||||
|
|
||||||
|
self.memory_truncated = mel_output
|
||||||
|
t += 1
|
||||||
|
|
||||||
|
outputs, gate_outputs, alignments = self._parse_outputs(
|
||||||
|
outputs, gate_outputs, alignments)
|
||||||
|
|
||||||
|
return outputs, gate_outputs, alignments
|
||||||
|
|
||||||
|
|
||||||
def inference_step(self, inputs, t, memory=None):
|
def inference_step(self, inputs, t, memory=None):
|
||||||
"""
|
"""
|
||||||
For debug purposes
|
For debug purposes
|
||||||
|
|
|
@ -46,6 +46,20 @@ class Tacotron2(nn.Module):
|
||||||
encoder_outputs)
|
encoder_outputs)
|
||||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
|
||||||
|
mel_outputs, mel_outputs_postnet, alignments)
|
||||||
|
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def inference_truncated(self, text):
|
||||||
|
"""
|
||||||
|
Preserve model states for continuous inference
|
||||||
|
"""
|
||||||
|
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||||
|
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
|
||||||
|
mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(encoder_outputs)
|
||||||
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
|
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
|
||||||
mel_outputs, mel_outputs_postnet, alignments)
|
mel_outputs, mel_outputs_postnet, alignments)
|
||||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
|
@ -8,19 +8,35 @@ from .visual import visualize
|
||||||
from matplotlib import pylab as plt
|
from matplotlib import pylab as plt
|
||||||
|
|
||||||
|
|
||||||
def synthesis(m, s, CONFIG, use_cuda, ap):
|
def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
|
||||||
|
"""Synthesize voice for the given text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (TTS.models): model to synthesize.
|
||||||
|
text (str): target text
|
||||||
|
CONFIG (dict): config dictionary to be loaded from config.json.
|
||||||
|
use_cuda (bool): enable cuda.
|
||||||
|
ap (TTS.utils.audio.AudioProcessor): audio processor to process
|
||||||
|
model outputs.
|
||||||
|
truncated (bool): keep model states after inference. It can be used
|
||||||
|
for continuous inference at long texts.
|
||||||
|
"""
|
||||||
text_cleaner = [CONFIG.text_cleaner]
|
text_cleaner = [CONFIG.text_cleaner]
|
||||||
if CONFIG.use_phonemes:
|
if CONFIG.use_phonemes:
|
||||||
seq = np.asarray(
|
seq = np.asarray(
|
||||||
phoneme_to_sequence(s, text_cleaner, CONFIG.phoneme_language),
|
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
else:
|
else:
|
||||||
seq = np.asarray(text_to_sequence(s, text_cleaner), dtype=np.int32)
|
seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
|
||||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
chars_var = chars_var.cuda()
|
chars_var = chars_var.cuda()
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = m.inference(
|
if truncated:
|
||||||
chars_var.long())
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||||
|
chars_var.long())
|
||||||
|
else:
|
||||||
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||||
|
chars_var.long())
|
||||||
postnet_output = postnet_output[0].data.cpu().numpy()
|
postnet_output = postnet_output[0].data.cpu().numpy()
|
||||||
decoder_output = decoder_output[0].data.cpu().numpy()
|
decoder_output = decoder_output[0].data.cpu().numpy()
|
||||||
alignment = alignments[0].cpu().data.numpy()
|
alignment = alignments[0].cpu().data.numpy()
|
||||||
|
|
Loading…
Reference in New Issue