Mirror of https://github.com/coqui-ai/TTS.git

Commit 4f89029577: Merge branch 'state-pass' into dev-tacotron2
@@ -202,6 +202,7 @@ class Encoder(nn.Module):
                             num_layers=1,
                             batch_first=True,
                             bidirectional=True)
+        self.rnn_state = None

     def forward(self, x, input_lengths):
         x = self.convolutions(x)
@@ -224,6 +225,16 @@ class Encoder(nn.Module):
         outputs, _ = self.lstm(x)
         return outputs

+    def inference_truncated(self, x):
+        """
+        Preserve encoder state for continuous inference
+        """
+        x = self.convolutions(x)
+        x = x.transpose(1, 2)
+        self.lstm.flatten_parameters()
+        outputs, self.rnn_state = self.lstm(x, self.rnn_state)
+        return outputs
+

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
     def __init__(self, in_features, inputs_dim, r, attn_win):
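For context on the state-passing trick above: `torch.nn.LSTM` treats a `None` hidden-state argument as a zero initial state and returns the final `(h, c)` pair, which can be fed back in on the next call. A minimal, self-contained sketch of that mechanic (the dimensions here are illustrative, not the repo's):

```python
import torch
import torch.nn as nn

# Toy stand-in for the encoder's recurrent layer; sizes are made up.
lstm = nn.LSTM(input_size=80, hidden_size=128, num_layers=1,
               batch_first=True, bidirectional=True)

rnn_state = None  # mirrors `self.rnn_state = None` in __init__
chunks = torch.rand(3, 1, 20, 80)  # three chunks of 20 frames, batch of 1

for chunk in chunks:
    # With state=None the LSTM starts from zeros; on later calls the
    # cached (h, c) tuple carries context over from the previous chunk.
    outputs, rnn_state = lstm(chunk, rnn_state)

print(outputs.shape)  # torch.Size([1, 20, 256]): 2 directions x 128 units
```

One caveat worth noting: for a bidirectional LSTM the carried backward-direction state comes from the previous chunk rather than from real future context, so truncated inference is an approximation of full-utterance encoding.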
@@ -264,31 +275,34 @@ class Decoder(nn.Module):
         self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
+        self.memory_truncated = None

     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = self.go_frame_init(inputs.data.new_zeros(B).long())
         return memory

-    def _init_states(self, inputs, mask):
+    def _init_states(self, inputs, mask, keep_states=False):
         B = inputs.size(0)
         T = inputs.size(1)

-        self.attention_hidden = self.attention_rnn_init(
-            inputs.data.new_zeros(B).long())
-        self.attention_cell = Variable(
-            inputs.data.new(B, self.attention_rnn_dim).zero_())
+        if not keep_states:
+            self.attention_hidden = self.attention_rnn_init(
+                inputs.data.new_zeros(B).long())
+            self.attention_cell = Variable(
+                inputs.data.new(B, self.attention_rnn_dim).zero_())

-        self.decoder_hidden = self.decoder_rnn_inits(
-            inputs.data.new_zeros(B).long())
-        self.decoder_cell = Variable(
-            inputs.data.new(B, self.decoder_rnn_dim).zero_())
+            self.decoder_hidden = self.decoder_rnn_inits(
+                inputs.data.new_zeros(B).long())
+            self.decoder_cell = Variable(
+                inputs.data.new(B, self.decoder_rnn_dim).zero_())

-        self.context = Variable(
-            inputs.data.new(B, self.encoder_embedding_dim).zero_())
+            self.context = Variable(
+                inputs.data.new(B, self.encoder_embedding_dim).zero_())
+
         self.attention_weights = Variable(inputs.data.new(B, T).zero_())
         self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

         self.inputs = inputs
         self.processed_inputs = self.attention_layer.inputs_layer(inputs)
         self.mask = mask
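The `keep_states` split above follows a simple rule: tensors whose shape depends on the new input length `T` (the attention weights) are rebuilt on every call, while fixed-size recurrent states and the context vector can survive between calls. A toy illustration of that reset-vs-carry pattern (class and method names here are hypothetical, not the repo's):

```python
import torch

class ToyStatefulDecoder:
    """Hypothetical minimal decoder that keeps RNN state between calls."""

    def __init__(self, hidden_dim=4):
        self.hidden_dim = hidden_dim
        self.hidden = None

    def init_states(self, inputs, keep_states=False):
        B, T = inputs.size(0), inputs.size(1)
        if not keep_states:
            # Fresh zero state: the start of a new, unrelated utterance.
            self.hidden = torch.zeros(B, self.hidden_dim)
        # Attention weights are sized by T, so they are rebuilt every time.
        self.attention_weights = torch.zeros(B, T)

dec = ToyStatefulDecoder()
dec.init_states(torch.rand(1, 50))                    # first chunk: reset
dec.hidden += 1.0                                     # pretend decoding ran
dec.init_states(torch.rand(1, 70), keep_states=True)  # next chunk: carried
print(dec.hidden[0, 0].item())  # 1.0, the state survived re-initialization
```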
@@ -399,6 +413,44 @@ class Decoder(nn.Module):

         return outputs, stop_tokens, alignments

+    def inference_truncated(self, inputs):
+        """
+        Preserve decoder states for continuous inference
+        """
+        if self.memory_truncated is None:
+            self.memory_truncated = self.get_go_frame(inputs)
+            self._init_states(inputs, mask=None, keep_states=False)
+        else:
+            self._init_states(inputs, mask=None, keep_states=True)
+
+        self.attention_layer.init_win_idx()
+        outputs, gate_outputs, alignments, t = [], [], [], 0
+        stop_flags = [False, False]
+        while True:
+            memory = self.prenet(self.memory_truncated)
+            mel_output, gate_output, alignment = self.decode(memory)
+            gate_output = torch.sigmoid(gate_output.data)
+            outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output]
+            alignments += [alignment]
+
+            stop_flags[0] = stop_flags[0] or gate_output > 0.5
+            stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
+            if all(stop_flags):
+                break
+            elif len(outputs) == self.max_decoder_steps:
+                print("   | > Decoder stopped with 'max_decoder_steps'")
+                break
+
+            self.memory_truncated = mel_output
+            t += 1
+
+        outputs, gate_outputs, alignments = self._parse_outputs(
+            outputs, gate_outputs, alignments)
+
+        return outputs, gate_outputs, alignments
+
+
     def inference_step(self, inputs, t, memory=None):
         """
         For debug purposes
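The stopping logic above combines two sticky flags and exits only once both have fired: the stop gate crossing 0.5, and the attention mass reaching the last two encoder steps. A standalone sketch of just that logic, with dummy tensors in place of real decoder outputs:

```python
import torch

stop_flags = [False, False]
max_decoder_steps = 500

for step in range(max_decoder_steps):
    # Dummy stand-ins for the decoder's per-step outputs.
    gate_output = torch.sigmoid(torch.randn(1, 1))
    alignment = torch.softmax(torch.randn(1, 50), dim=1)

    # Flags are sticky: once a condition fires it stays on, so a brief
    # dip back below threshold cannot un-stop the decoder.
    stop_flags[0] = stop_flags[0] or bool(gate_output > 0.5)
    stop_flags[1] = stop_flags[1] or bool(alignment[0, -2:].sum() > 0.5)
    if all(stop_flags):
        break
```

Requiring both signals guards against a single noisy gate spike ending synthesis before the attention has actually consumed the whole input.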
@@ -46,6 +46,20 @@ class Tacotron2(nn.Module):
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
+
+    def inference_truncated(self, text):
+        """
+        Preserve model states for continuous inference
+        """
+        embedded_inputs = self.embedding(text).transpose(1, 2)
+        encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
+        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
+            encoder_outputs)
+        mel_outputs_postnet = self.postnet(mel_outputs)
+        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
+            mel_outputs, mel_outputs_postnet, alignments)
+        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
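Taken together, the model-level method above suggests a chunk-by-chunk call pattern, with the encoder's `rnn_state` and the decoder's `memory_truncated` bridging the gaps. A hedged usage sketch; `model` and `chunks` are assumed stand-ins, not objects defined in this diff:

```python
import torch

# Assumptions: `model` is a loaded Tacotron2 instance and `chunks` is a
# list of character-id tensors, one per text chunk.
mel_parts = []
for chunk in chunks:
    mel, mel_postnet, alignments, stop_tokens = model.inference_truncated(chunk)
    mel_parts.append(mel_postnet)

# Concatenate along the time axis for the full-utterance spectrogram.
full_mel = torch.cat(mel_parts, dim=1)
```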
@@ -66,4 +66,4 @@ class TacotronTrainTest(unittest.TestCase):
             assert (param != param_ref).any(
             ), "param {} with shape {} not updated!! \n{}\n{}".format(
                 count, param.shape, param, param_ref)
-        count += 1
+            count += 1
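The assertion block follows the usual snapshot-step-compare recipe for verifying that a training step touches every parameter, with `count` only there to make the failure message readable. A self-contained sketch of the pattern on a trivial model:

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model_ref = copy.deepcopy(model)  # frozen snapshot of the initial weights

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.rand(8, 4)).sum()
loss.backward()
optimizer.step()

count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
    # Any surviving equality means the optimizer never updated this tensor.
    assert (param != param_ref).any(), \
        "param {} with shape {} not updated!!".format(count, param.shape)
    count += 1  # increment once per parameter, inside the loop
```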
@@ -22,6 +22,7 @@ class TacotronTrainTest(unittest.TestCase):
     def test_train_step(self):
         input = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8, )).long().to(device)
+        input_lengths[-1] = 128
         mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
         linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
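Pinning `input_lengths[-1] = 128` presumably guarantees that at least one sample spans the full padded length, so that the pack/pad round-trip inside the encoder returns sequences with the same time dimension as the input (after `pad_packed_sequence`, the length equals `max(lengths)`). A small demonstration:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.rand(8, 128, 16)                      # padded batch, T = 128
lengths = torch.randint(100, 129, (8,)).long()
lengths[-1] = 128                               # ensure one full-length row

lengths, order = lengths.sort(descending=True)  # packing wants sorted lengths
x = x[order]                                    # keep rows aligned with lengths

packed = pack_padded_sequence(x, lengths, batch_first=True)
out, out_lengths = pad_packed_sequence(packed, batch_first=True)
print(out.shape[1])  # 128, matching padded T only because max(lengths) == 128
```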
@@ -8,19 +8,35 @@ from .visual import visualize
 from matplotlib import pylab as plt


-def synthesis(m, s, CONFIG, use_cuda, ap):
+def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
+    """Synthesize voice for the given text.
+
+    Args:
+        model (TTS.models): model to synthesize.
+        text (str): target text
+        CONFIG (dict): config dictionary to be loaded from config.json.
+        use_cuda (bool): enable cuda.
+        ap (TTS.utils.audio.AudioProcessor): audio processor to process
+            model outputs.
+        truncated (bool): keep model states after inference. It can be used
+            for continuous inference at long texts.
+    """
     text_cleaner = [CONFIG.text_cleaner]
     if CONFIG.use_phonemes:
         seq = np.asarray(
-            phoneme_to_sequence(s, text_cleaner, CONFIG.phoneme_language),
+            phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
             dtype=np.int32)
     else:
-        seq = np.asarray(text_to_sequence(s, text_cleaner), dtype=np.int32)
+        seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda()
-    decoder_output, postnet_output, alignments, stop_tokens = m.inference(
-        chars_var.long())
+    if truncated:
+        decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+            chars_var.long())
+    else:
+        decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+            chars_var.long())
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
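A likely call pattern for the new `truncated` flag, feeding a long text sentence by sentence (`model`, `CONFIG`, and `ap` are assumed to be set up as elsewhere in the repo; the exact return values of `synthesis` are not shown in this hunk):

```python
long_text = "First sentence. Second sentence. Third sentence."

for sentence in long_text.split(". "):
    # truncated=True keeps encoder and decoder states alive between calls,
    # so context can flow across sentence boundaries.
    outputs = synthesis(model, sentence, CONFIG, use_cuda=False, ap=ap,
                        truncated=True)
```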