From b9b79fcf0f342cd7ad90df74742013c80de2567c Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 11 Mar 2019 17:40:09 +0100 Subject: [PATCH 1/2] inference truncated NEED TO BE TESTED --- layers/tacotron2.py | 76 ++++++++++++++++++++++++++++++++++++++------- models/tacotron2.py | 14 +++++++++ utils/synthesis.py | 26 +++++++++++++--- 3 files changed, 99 insertions(+), 17 deletions(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index a311016e..1447768f 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -202,6 +202,7 @@ class Encoder(nn.Module): num_layers=1, batch_first=True, bidirectional=True) + self.rnn_state = None def forward(self, x, input_lengths): x = self.convolutions(x) @@ -224,6 +225,16 @@ class Encoder(nn.Module): outputs, _ = self.lstm(x) return outputs + def inference_truncated(self, x): + """ + Preserve encoder state for continuous inference + """ + x = self.convolutions(x) + x = x.transpose(1, 2) + self.lstm.flatten_parameters() + outputs, self.rnn_state = self.lstm(x, self.rnn_state) + return outputs + # adapted from https://github.com/NVIDIA/tacotron2/ class Decoder(nn.Module): def __init__(self, in_features, inputs_dim, r, attn_win): @@ -264,31 +275,34 @@ class Decoder(nn.Module): self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim) self.go_frame_init = nn.Embedding(1, self.mel_channels * r) self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim) + self.memory_truncated = None def get_go_frame(self, inputs): B = inputs.size(0) memory = self.go_frame_init(inputs.data.new_zeros(B).long()) return memory - def _init_states(self, inputs, mask): + def _init_states(self, inputs, mask, keep_states=False): B = inputs.size(0) T = inputs.size(1) - self.attention_hidden = self.attention_rnn_init( - inputs.data.new_zeros(B).long()) - self.attention_cell = Variable( - inputs.data.new(B, self.attention_rnn_dim).zero_()) + if not keep_states: + self.attention_hidden = self.attention_rnn_init( + inputs.data.new_zeros(B).long()) + self.attention_cell = Variable( + inputs.data.new(B, self.attention_rnn_dim).zero_()) - self.decoder_hidden = self.decoder_rnn_inits( - inputs.data.new_zeros(B).long()) - self.decoder_cell = Variable( - inputs.data.new(B, self.decoder_rnn_dim).zero_()) + self.decoder_hidden = self.decoder_rnn_inits( + inputs.data.new_zeros(B).long()) + self.decoder_cell = Variable( + inputs.data.new(B, self.decoder_rnn_dim).zero_()) + + self.context = Variable( + inputs.data.new(B, self.encoder_embedding_dim).zero_()) self.attention_weights = Variable(inputs.data.new(B, T).zero_()) self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_()) - self.context = Variable( - inputs.data.new(B, self.encoder_embedding_dim).zero_()) - + self.inputs = inputs self.processed_inputs = self.attention_layer.inputs_layer(inputs) self.mask = mask @@ -399,6 +413,44 @@ class Decoder(nn.Module): return outputs, gate_outputs, alignments + def inference_truncated(self, inputs): + """ + Preserve decoder states for continuous inference + """ + if self.memory_truncated is None: + self.memory_truncated = self.get_go_frame(inputs) + self._init_states(inputs, mask=None, keep_states=False) + else: + self._init_states(inputs, mask=None, keep_states=True) + + self.attention_layer.init_win_idx() + outputs, gate_outputs, alignments, t = [], [], [], 0 + stop_flags = [False, False] + while True: + memory = self.prenet(self.memory_truncated) + mel_output, gate_output, alignment = self.decode(memory) + gate_output = torch.sigmoid(gate_output.data) + outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output] + alignments += [alignment] + + stop_flags[0] = stop_flags[0] or gate_output > 0.5 + stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5 + if all(stop_flags): + break + elif len(outputs) == self.max_decoder_steps: + print(" | > Decoder stopped with 'max_decoder_steps") + break + + self.memory_truncated = mel_output + t += 1 + + outputs, gate_outputs, alignments = self._parse_outputs( + outputs, gate_outputs, alignments) + + return outputs, gate_outputs, alignments + + def inference_step(self, inputs, t, memory=None): """ For debug purposes diff --git a/models/tacotron2.py b/models/tacotron2.py index c0cbba26..e4082848 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -46,6 +46,20 @@ class Tacotron2(nn.Module): encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet + mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs( + mel_outputs, mel_outputs_postnet, alignments) + return mel_outputs, mel_outputs_postnet, alignments, stop_tokens + + + def inference_truncated(self, text): + """ + Preserve model states for continuous inference + """ + embedded_inputs = self.embedding(text).transpose(1, 2) + encoder_outputs = self.encoder.inference_truncated(embedded_inputs) + mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(encoder_outputs) + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs( mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens \ No newline at end of file diff --git a/utils/synthesis.py b/utils/synthesis.py index 2c26e883..ac612e60 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -8,19 +8,35 @@ from .visual import visualize from matplotlib import pylab as plt -def synthesis(m, s, CONFIG, use_cuda, ap): +def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False): + """Synthesize voice for the given text. + + Args: + model (TTS.models): model to synthesize. + text (str): target text + CONFIG (dict): config dictionary to be loaded from config.json. + use_cuda (bool): enable cuda. + ap (TTS.utils.audio.AudioProcessor): audio processor to process + model outputs. + truncated (bool): keep model states after inference. It can be used + for continuous inference at long texts. + """ text_cleaner = [CONFIG.text_cleaner] if CONFIG.use_phonemes: seq = np.asarray( - phoneme_to_sequence(s, text_cleaner, CONFIG.phoneme_language), + phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language), dtype=np.int32) else: - seq = np.asarray(text_to_sequence(s, text_cleaner), dtype=np.int32) + seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) chars_var = torch.from_numpy(seq).unsqueeze(0) if use_cuda: chars_var = chars_var.cuda() - decoder_output, postnet_output, alignments, stop_tokens = m.inference( - chars_var.long()) + if truncated: + decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( + chars_var.long()) + else: + decoder_output, postnet_output, alignments, stop_tokens = model.inference( + chars_var.long()) postnet_output = postnet_output[0].data.cpu().numpy() decoder_output = decoder_output[0].data.cpu().numpy() alignment = alignments[0].cpu().data.numpy() From 65ffbae23d0e8f73479f4d4fe49b68a06ce503f8 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 12 Mar 2019 09:52:01 +0100 Subject: [PATCH 2/2] test bug fix --- tests/tacotron2_tests.py | 2 +- tests/tacotron_tests.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tacotron2_tests.py b/tests/tacotron2_tests.py index 56c5a1a1..c2f212f9 100644 --- a/tests/tacotron2_tests.py +++ b/tests/tacotron2_tests.py @@ -66,4 +66,4 @@ class TacotronTrainTest(unittest.TestCase): assert (param != param_ref).any( ), "param {} with shape {} not updated!! \n{}\n{}".format( count, param.shape, param, param_ref) - count += 1 \ No newline at end of file + count += 1 diff --git a/tests/tacotron_tests.py b/tests/tacotron_tests.py index 2f76469a..77195594 100644 --- a/tests/tacotron_tests.py +++ b/tests/tacotron_tests.py @@ -22,6 +22,7 @@ class TacotronTrainTest(unittest.TestCase): def test_train_step(self): input = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths[-1] = 128 mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device)