Merge branch 'state-pass' into dev-tacotron2

Eren Golge 2019-03-12 09:52:15 +01:00
commit 4f89029577
5 changed files with 101 additions and 18 deletions

View File

@@ -202,6 +202,7 @@ class Encoder(nn.Module):
            num_layers=1,
            batch_first=True,
            bidirectional=True)
        self.rnn_state = None

    def forward(self, x, input_lengths):
        x = self.convolutions(x)
@@ -224,6 +225,16 @@ class Encoder(nn.Module):
        outputs, _ = self.lstm(x)
        return outputs

    def inference_truncated(self, x):
        """
        Preserve encoder state for continuous inference
        """
        x = self.convolutions(x)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, self.rnn_state = self.lstm(x, self.rnn_state)
        return outputs
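Because the encoder now stores self.rnn_state, consecutive calls resume the LSTM from where the previous chunk ended. A sketch of the intended call pattern (the chunk loop and the reset are illustrative assumptions, not part of this commit):

# Hypothetical chunked encoding; each `chunk` is the embedded input for
# one piece of the utterance, shaped like the input of forward().
encoder.rnn_state = None  # start of a new utterance
encoder_outputs = [encoder.inference_truncated(chunk) for chunk in chunks]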
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
    def __init__(self, in_features, inputs_dim, r, attn_win):
@@ -264,31 +275,34 @@ class Decoder(nn.Module):
        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
        self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
        self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
        self.memory_truncated = None

    def get_go_frame(self, inputs):
        B = inputs.size(0)
        memory = self.go_frame_init(inputs.data.new_zeros(B).long())
        return memory

    def _init_states(self, inputs, mask, keep_states=False):
        B = inputs.size(0)
        T = inputs.size(1)

        if not keep_states:
            self.attention_hidden = self.attention_rnn_init(
                inputs.data.new_zeros(B).long())
            self.attention_cell = Variable(
                inputs.data.new(B, self.attention_rnn_dim).zero_())

            self.decoder_hidden = self.decoder_rnn_inits(
                inputs.data.new_zeros(B).long())
            self.decoder_cell = Variable(
                inputs.data.new(B, self.decoder_rnn_dim).zero_())

            self.context = Variable(
                inputs.data.new(B, self.encoder_embedding_dim).zero_())

        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())

        self.inputs = inputs
        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
        self.mask = mask
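With keep_states=True the recurrent hidden states and the context vector carry over from the previous call, while the attention weights are always re-zeroed because their length T tracks the new input. The two call paths, as inference_truncated below uses them (the encoder-output names are placeholders):

    # first chunk of an utterance: everything initialized from scratch
    decoder._init_states(encoder_outputs, mask=None, keep_states=False)
    # later chunks: RNN states and context persist; only the attention
    # weights are reset to match the new input length
    decoder._init_states(next_encoder_outputs, mask=None, keep_states=True)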
@@ -399,6 +413,44 @@ class Decoder(nn.Module):
        return outputs, stop_tokens, alignments

    def inference_truncated(self, inputs):
        """
        Preserve decoder states for continuous inference
        """
        if self.memory_truncated is None:
            self.memory_truncated = self.get_go_frame(inputs)
            self._init_states(inputs, mask=None, keep_states=False)
        else:
            self._init_states(inputs, mask=None, keep_states=True)

        self.attention_layer.init_win_idx()
        outputs, gate_outputs, alignments, t = [], [], [], 0
        stop_flags = [False, False]
        while True:
            memory = self.prenet(self.memory_truncated)
            mel_output, gate_output, alignment = self.decode(memory)
            gate_output = torch.sigmoid(gate_output.data)
            outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            # stop only when the gate fires AND attention has reached
            # the last encoder steps
            stop_flags[0] = stop_flags[0] or gate_output > 0.5
            stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
            if all(stop_flags):
                break
            elif len(outputs) == self.max_decoder_steps:
                print(" | > Decoder stopped with 'max_decoder_steps'")
                break

            self.memory_truncated = mel_output
            t += 1

        outputs, gate_outputs, alignments = self._parse_outputs(
            outputs, gate_outputs, alignments)

        return outputs, gate_outputs, alignments
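Since the decoder keeps memory_truncated and its RNN states between calls, a caller must clear them before synthesizing an unrelated utterance. A minimal reset sketch (this caller code is an assumption, not part of the commit):

# clearing the kept frame makes the next inference_truncated call
# re-initialize all decoder states (the keep_states=False path)
decoder.memory_truncated = None
encoder.rnn_state = None  # reset the encoder's kept LSTM state too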
    def inference_step(self, inputs, t, memory=None):
        """
        For debug purposes

View File

@@ -46,6 +46,20 @@ class Tacotron2(nn.Module):
            encoder_outputs)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
            mel_outputs, mel_outputs_postnet, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

    def inference_truncated(self, text):
        """
        Preserve model states for continuous inference
        """
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
            encoder_outputs)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
            mel_outputs, mel_outputs_postnet, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
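At the model level, inference_truncated mirrors inference but routes through the state-keeping encoder and decoder paths, so consecutive chunks of one utterance share state across chunk boundaries. A usage sketch (the chunk list is an illustrative assumption):

# Hypothetical chunked inference; chunked_inputs is a list of (1, T)
# LongTensors covering one long utterance in order.
mels = []
for chars_var in chunked_inputs:
    mel, mel_postnet, alignments, stop_tokens = model.inference_truncated(
        chars_var)
    mels.append(mel_postnet)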

View File

@@ -66,4 +66,4 @@ class TacotronTrainTest(unittest.TestCase):
            assert (param != param_ref).any(
            ), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref)
            count += 1

View File

@@ -22,6 +22,7 @@ class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):
        input = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
        # ensure the longest sequence spans the full padded length
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)

View File

@@ -8,19 +8,35 @@ from .visual import visualize
from matplotlib import pylab as plt


def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
    """Synthesize voice for the given text.

    Args:
        model (TTS.models): model to synthesize.
        text (str): target text.
        CONFIG (dict): config dictionary loaded from config.json.
        use_cuda (bool): enable cuda.
        ap (TTS.utils.audio.AudioProcessor): audio processor to process
            model outputs.
        truncated (bool): keep model states after inference. Useful for
            continuous inference on long texts.
    """
    text_cleaner = [CONFIG.text_cleaner]
    if CONFIG.use_phonemes:
        seq = np.asarray(
            phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
            dtype=np.int32)
    else:
        seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:
        chars_var = chars_var.cuda()
    if truncated:
        decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
            chars_var.long())
    else:
        decoder_output, postnet_output, alignments, stop_tokens = model.inference(
            chars_var.long())
    postnet_output = postnet_output[0].data.cpu().numpy()
    decoder_output = decoder_output[0].data.cpu().numpy()
    alignment = alignments[0].cpu().data.numpy()
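With truncated=True, a long text can be synthesized sentence by sentence while the model carries its states across calls. A minimal caller sketch (the sentence splitting is an illustrative assumption, and the per-call return value is whatever synthesis returns in the part of the function not shown in this diff):

# Hypothetical caller: feed a long text piece by piece so the model's
# encoder/decoder states persist between sentences.
sentences = long_text.split(". ")  # naive split, illustration only
results = [synthesis(model, sent, CONFIG, use_cuda, ap, truncated=True)
           for sent in sentences]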