Add truncated inference for continuous synthesis (needs to be tested)

Eren Golge 2019-03-11 17:40:09 +01:00
parent 5cbe0f83f6
commit b9b79fcf0f
3 changed files with 99 additions and 17 deletions

layers/tacotron2.py

@@ -202,6 +202,7 @@ class Encoder(nn.Module):
             num_layers=1,
             batch_first=True,
             bidirectional=True)
+        self.rnn_state = None

     def forward(self, x, input_lengths):
         x = self.convolutions(x)
@@ -224,6 +225,16 @@ class Encoder(nn.Module):
         outputs, _ = self.lstm(x)
         return outputs

+    def inference_truncated(self, x):
+        """
+        Preserve encoder state for continuous inference
+        """
+        x = self.convolutions(x)
+        x = x.transpose(1, 2)
+        self.lstm.flatten_parameters()
+        outputs, self.rnn_state = self.lstm(x, self.rnn_state)
+        return outputs
+

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
     def __init__(self, in_features, inputs_dim, r, attn_win):
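The encoder change relies on the standard stateful-LSTM idiom: torch's nn.LSTM takes an optional (h_0, c_0) tuple and returns the final (h_n, c_n), so storing the returned state and passing it back in resumes the sequence where the previous chunk ended, while passing None starts from zero states. A minimal standalone sketch of that idiom (toy sizes, not code from this commit):

    import torch
    from torch import nn

    lstm = nn.LSTM(input_size=8, hidden_size=4, batch_first=True)
    state = None  # None = zero-initialized states, like self.rnn_state at startup

    # three consecutive chunks of one stream, each of shape (batch=1, time=5, feat=8)
    for chunk in torch.randn(3, 1, 5, 8):
        # feeding the previous (h, c) back in continues the sequence across chunks
        out, state = lstm(chunk, state)

    state = None  # reset before an unrelated utterance

Note the commit adds no reset hook; clearing self.rnn_state between unrelated utterances is left to the caller.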
@@ -264,31 +275,34 @@ class Decoder(nn.Module):
         self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
         self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
+        self.memory_truncated = None

     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = self.go_frame_init(inputs.data.new_zeros(B).long())
         return memory

-    def _init_states(self, inputs, mask):
+    def _init_states(self, inputs, mask, keep_states=False):
         B = inputs.size(0)
         T = inputs.size(1)

-        self.attention_hidden = self.attention_rnn_init(
-            inputs.data.new_zeros(B).long())
-        self.attention_cell = Variable(
-            inputs.data.new(B, self.attention_rnn_dim).zero_())
-
-        self.decoder_hidden = self.decoder_rnn_inits(
-            inputs.data.new_zeros(B).long())
-        self.decoder_cell = Variable(
-            inputs.data.new(B, self.decoder_rnn_dim).zero_())
+        if not keep_states:
+            self.attention_hidden = self.attention_rnn_init(
+                inputs.data.new_zeros(B).long())
+            self.attention_cell = Variable(
+                inputs.data.new(B, self.attention_rnn_dim).zero_())
+
+            self.decoder_hidden = self.decoder_rnn_inits(
+                inputs.data.new_zeros(B).long())
+            self.decoder_cell = Variable(
+                inputs.data.new(B, self.decoder_rnn_dim).zero_())
+
+            self.context = Variable(
+                inputs.data.new(B, self.encoder_embedding_dim).zero_())

         self.attention_weights = Variable(inputs.data.new(B, T).zero_())
         self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
-        self.context = Variable(
-            inputs.data.new(B, self.encoder_embedding_dim).zero_())

         self.inputs = inputs
         self.processed_inputs = self.attention_layer.inputs_layer(inputs)
         self.mask = mask
@@ -399,6 +413,44 @@ class Decoder(nn.Module):
         return outputs, gate_outputs, alignments

+    def inference_truncated(self, inputs):
+        """
+        Preserve decoder states for continuous inference
+        """
+        if self.memory_truncated is None:
+            self.memory_truncated = self.get_go_frame(inputs)
+            self._init_states(inputs, mask=None, keep_states=False)
+        else:
+            self._init_states(inputs, mask=None, keep_states=True)
+
+        self.attention_layer.init_win_idx()
+        outputs, gate_outputs, alignments, t = [], [], [], 0
+        stop_flags = [False, False]
+        while True:
+            memory = self.prenet(self.memory_truncated)
+            mel_output, gate_output, alignment = self.decode(memory)
+            gate_output = torch.sigmoid(gate_output.data)
+            outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output]
+            alignments += [alignment]
+
+            stop_flags[0] = stop_flags[0] or gate_output > 0.5
+            stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
+            if all(stop_flags):
+                break
+            elif len(outputs) == self.max_decoder_steps:
+                print("   | > Decoder stopped with 'max_decoder_steps'")
+                break
+
+            self.memory_truncated = mel_output
+            t += 1
+
+        outputs, gate_outputs, alignments = self._parse_outputs(
+            outputs, gate_outputs, alignments)
+
+        return outputs, gate_outputs, alignments
+

     def inference_step(self, inputs, t, memory=None):
         """
         For debug purposes
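Note how the decoder's stopping condition works: decoding ends only after both criteria have fired at least once, the stop gate exceeding 0.5 and the attention mass on the last two encoder frames exceeding 0.5. Because the flags are or-accumulated (sticky), the two events need not happen at the same step. A toy trace of that logic (made-up numbers, not model output):

    gate = [0.1, 0.6, 0.4, 0.3]       # sigmoid stop-gate value per decoder step
    attn_tail = [0.0, 0.1, 0.2, 0.7]  # attention mass on the last two frames

    stop_flags = [False, False]
    for t, (g, a) in enumerate(zip(gate, attn_tail)):
        stop_flags[0] = stop_flags[0] or g > 0.5  # gate fires at step 1, stays set
        stop_flags[1] = stop_flags[1] or a > 0.5  # alignment reaches the end at step 3
        if all(stop_flags):
            print(f"stopped at step {t}")  # prints: stopped at step 3
            break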

models/tacotron2.py

@@ -46,6 +46,20 @@ class Tacotron2(nn.Module):
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
+        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
+            mel_outputs, mel_outputs_postnet, alignments)
+        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
+
+    def inference_truncated(self, text):
+        """
+        Preserve model states for continuous inference
+        """
+        embedded_inputs = self.embedding(text).transpose(1, 2)
+        encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
+        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(encoder_outputs)
+        mel_outputs_postnet = self.postnet(mel_outputs)
+        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
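With the encoder and decoder halves in place, continuous inference becomes one inference_truncated call per text chunk; the encoder's rnn_state and the decoder's memory_truncated persist between calls. A hedged usage sketch (text_to_tensor is a hypothetical helper, and the manual reset at the end is an assumption, since this commit adds no reset method):

    model.eval()
    mels = []
    with torch.no_grad():
        for chunk in ["first part of a long text,", "and its continuation."]:
            chars = text_to_tensor(chunk)  # hypothetical: str -> LongTensor of shape (1, T)
            _, mel_postnet, _, _ = model.inference_truncated(chars)
            mels.append(mel_postnet)

    # states carried over from the last call must be cleared by hand
    # before synthesizing an unrelated utterance:
    model.encoder.rnn_state = None
    model.decoder.memory_truncated = None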

utils/synthesis.py

@@ -8,19 +8,35 @@ from .visual import visualize
 from matplotlib import pylab as plt


-def synthesis(m, s, CONFIG, use_cuda, ap):
+def synthesis(model, text, CONFIG, use_cuda, ap, truncated=False):
+    """Synthesize voice for the given text.
+
+    Args:
+        model (TTS.models): model to synthesize.
+        text (str): target text.
+        CONFIG (dict): config dictionary to be loaded from config.json.
+        use_cuda (bool): enable cuda.
+        ap (TTS.utils.audio.AudioProcessor): audio processor to process
+            model outputs.
+        truncated (bool): keep model states after inference. It can be used
+            for continuous inference on long texts.
+    """
     text_cleaner = [CONFIG.text_cleaner]
     if CONFIG.use_phonemes:
         seq = np.asarray(
-            phoneme_to_sequence(s, text_cleaner, CONFIG.phoneme_language),
+            phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language),
             dtype=np.int32)
     else:
-        seq = np.asarray(text_to_sequence(s, text_cleaner), dtype=np.int32)
+        seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32)
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda()
-    decoder_output, postnet_output, alignments, stop_tokens = m.inference(
-        chars_var.long())
+    if truncated:
+        decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+            chars_var.long())
+    else:
+        decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+            chars_var.long())
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
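From caller code, the new flag means a long text can be split and synthesized piece by piece while the model keeps its context. A hedged example (the naive sentence split and the handling of the return value are illustrative; this hunk does not show what synthesis returns further down):

    long_text = "A very long article. It spans many sentences."
    for sentence in long_text.split(". "):  # naive sentence split, for illustration only
        outputs = synthesis(model, sentence, CONFIG, use_cuda, ap,
                            truncated=True)  # model states persist across calls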