Modularize functions in Tacotron

Eren Golge 2019-03-05 13:25:50 +01:00
parent cc34fe4c7c
commit 1e8fdec084
5 changed files with 121 additions and 91 deletions

config.json

@@ -29,7 +29,6 @@
         "url": "tcp:\/\/localhost:54321"
     },
-    "embedding_size": 256, // Character embedding vector length. You don't need to change it in general.
     "text_cleaner": "phoneme_cleaners",
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.

layers/tacotron.py

@@ -301,7 +301,8 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     """

-    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
+    def __init__(self, in_features, memory_dim, r, memory_size,
+                 attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -309,7 +310,8 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * self.memory_size, out_features=[256, 128])
+        self.prenet = Prenet(
+            memory_dim * self.memory_size, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -360,116 +362,135 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         T = inputs.size(1)
         # go frame as zeros matrix
-        initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
+        self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
         # decoder states
-        attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
-        decoder_rnn_hiddens = [
-            self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
+        self.attention_rnn_hidden = self.attention_rnn_init(
+            inputs.data.new_zeros(B).long())
+        self.decoder_rnn_hiddens = [
+            self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
         # attention states
-        attention = inputs.data.new(B, T).zero_()
-        attention_cum = inputs.data.new(B, T).zero_()
-        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
-                current_context_vec, attention, attention_cum)
+        self.attention = inputs.data.new(B, T).zero_()
+        self.attention_cum = inputs.data.new(B, T).zero_()

-    def forward(self, inputs, memory=None, mask=None):
-        """
-        Decoder forward step.
-        If decoder inputs are not given (e.g., at testing time), as noted in
-        the Tacotron paper, greedy decoding is adopted.
-        Args:
-            inputs: Encoder outputs.
-            memory (None): Decoder memory (autoregression). If None (at eval
-                time), the last decoder output is fed back as the next input.
-            mask (None): Attention mask for sequence padding.
-        Shapes:
-            - inputs: batch x time x encoder_out_dim
-            - memory: batch x #mel_specs x mel_spec_dim
-        """
-        # Run greedy decoding if memory is None
-        greedy = not self.training
-        if memory is not None:
-            memory = self._reshape_memory(memory)
-            T_decoder = memory.size(0)
-        outputs = []
-        attentions = []
-        stop_tokens = []
-        t = 0
-        memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
-            current_context_vec, attention, attention_cum = self._init_states(inputs)
-        while True:
-            if t > 0:
-                if memory is None:
-                    new_memory = outputs[-1]
-                else:
-                    new_memory = memory[t - 1]
-                # Queue new memory if memory_size is set, else use only the previous prediction.
-                if self.memory_size > 0:
-                    memory_input = torch.cat([
-                        memory_input[:, self.r * self.memory_dim:].clone(),
-                        new_memory
-                    ], dim=-1)
-                else:
-                    memory_input = new_memory
-            # Prenet
-            processed_memory = self.prenet(memory_input)
-            # Attention RNN
-            attention_cat = torch.cat(
-                (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
-            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention_cat, mask, t)
-            del attention_cat
-            attention_cum += attention
-            # Concat RNN output and attention context vector
-            decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_context_vec), -1))
-            # Pass through the decoder RNNs
-            for idx in range(len(self.decoder_rnns)):
-                decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
-                    decoder_input, decoder_rnn_hiddens[idx])
-                # Residual connection
-                decoder_input = decoder_rnn_hiddens[idx] + decoder_input
-            decoder_output = decoder_input
-            del decoder_input
-            # predict mel vectors from decoder vectors
-            output = self.proj_to_mel(decoder_output)
-            output = torch.sigmoid(output)
-            # predict stop token
-            stopnet_input = torch.cat([decoder_output, output], -1)
-            del decoder_output
-            stop_token = self.stopnet(stopnet_input)
-            del stopnet_input
-            outputs += [output]
-            attentions += [attention]
-            stop_tokens += [stop_token]
-            del output
-            t += 1
-            if memory is not None:
-                if t >= T_decoder:
-                    break
-            else:
-                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
-                                                attention[:, -1].item() > 0.6):
-                    break
-                elif t > self.max_decoder_steps:
-                    print(" | > Decoder stopped with 'max_decoder_steps'")
-                    break
-        # Back to batch first
-        attentions = torch.stack(attentions).transpose(0, 1)
-        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, attentions, stop_tokens
+    def _parse_outputs(self, outputs, attentions, stop_tokens):
+        # Back to batch first
+        attentions = torch.stack(attentions).transpose(0, 1)
+        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
+        return outputs, attentions, stop_tokens
+
+    def decode(self, inputs, t, mask=None):
+        # Prenet
+        processed_memory = self.prenet(self.memory_input)
+        # Attention RNN
+        attention_cat = torch.cat(
+            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)),
+            dim=1)
+        self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
+            processed_memory, self.current_context_vec,
+            self.attention_rnn_hidden, inputs, attention_cat, mask, t)
+        del attention_cat
+        self.attention_cum += self.attention
+        # Concat RNN output and attention context vector
+        decoder_input = self.project_to_decoder_in(
+            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
+                      -1))
+        # Pass through the decoder RNNs
+        for idx in range(len(self.decoder_rnns)):
+            self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
+                decoder_input, self.decoder_rnn_hiddens[idx])
+            # Residual connection
+            decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
+        decoder_output = decoder_input
+        del decoder_input
+        # predict mel vectors from decoder vectors
+        output = self.proj_to_mel(decoder_output)
+        output = torch.sigmoid(output)
+        # predict stop token
+        stopnet_input = torch.cat([decoder_output, output], -1)
+        del decoder_output
+        stop_token = self.stopnet(stopnet_input)
+        return output, stop_token, self.attention
+
+    def _update_memory_queue(self, new_memory):
+        # Queue new memory if memory_size is set, else use only the previous prediction.
+        if self.memory_size > 0:
+            self.memory_input = torch.cat([
+                self.memory_input[:, self.r * self.memory_dim:].clone(),
+                new_memory
+            ], dim=-1)
+        else:
+            self.memory_input = new_memory
+
+    def forward(self, inputs, memory, mask):
+        """
+        Args:
+            inputs: Encoder outputs.
+            memory: Decoder memory (ground-truth frames for teacher forcing).
+            mask: Attention mask for sequence padding.
+        Shapes:
+            - inputs: batch x time x encoder_out_dim
+            - memory: batch x #mel_specs x mel_spec_dim
+        """
+        memory = self._reshape_memory(memory)
+        outputs = []
+        attentions = []
+        stop_tokens = []
+        t = 0
+        self._init_states(inputs)
+        while len(outputs) < memory.size(0):
+            if t > 0:
+                new_memory = memory[t - 1]
+                self._update_memory_queue(new_memory)
+            output, stop_token, attention = self.decode(inputs, t, mask)
+            outputs += [output]
+            attentions += [attention]
+            stop_tokens += [stop_token]
+            t += 1
+        return self._parse_outputs(outputs, attentions, stop_tokens)
+
+    def inference(self, inputs):
+        """
+        Args:
+            inputs: Encoder outputs.
+        Shapes:
+            - inputs: batch x time x encoder_out_dim
+        """
+        outputs = []
+        attentions = []
+        stop_tokens = []
+        t = 0
+        self._init_states(inputs)
+        while True:
+            if t > 0:
+                new_memory = outputs[-1]
+                self._update_memory_queue(new_memory)
+            output, stop_token, attention = self.decode(inputs, t, None)
+            outputs += [output]
+            attentions += [attention]
+            stop_tokens += [stop_token]
+            t += 1
+            if t > inputs.shape[1] / 4 and (stop_token > 0.6
+                                            or attention[:, -1].item() > 0.6):
+                break
+            elif t > self.max_decoder_steps:
+                print(" | > Decoder stopped with 'max_decoder_steps'")
+                break
+        return self._parse_outputs(outputs, attentions, stop_tokens)


 class StopNet(nn.Module):
     r"""
+    Predicting the stop token in the decoder.

     Args:
         in_features (int): feature dimension of input.
     """

models/tacotron.py

@@ -8,21 +8,19 @@ from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG

 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
-                 embedding_dim=256,
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
                  padding_idx=None,
                  memory_size=5,
                  attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(
-            num_chars, embedding_dim, padding_idx=padding_idx)
+        self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
-        self.encoder = Encoder(embedding_dim)
+        self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
@@ -37,9 +35,22 @@ class Tacotron(nn.Module):
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
-        # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens
+
+    def inference(self, characters):
+        B = characters.size(0)
+        inputs = self.embedding(characters)
+        # batch x time x dim
+        encoder_outputs = self.encoder(inputs)
+        # batch x time x dim*r
+        mel_outputs, alignments, stop_tokens = self.decoder.inference(
+            encoder_outputs)
+        # batch x time x dim
+        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
+        linear_outputs = self.postnet(mel_outputs)
+        linear_outputs = self.last_linear(linear_outputs)
+        return mel_outputs, linear_outputs, alignments, stop_tokens
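
At the model level the same split appears as `forward` for teacher-forced training and the new `inference` for free-running synthesis. A minimal eval-time sketch (sizes are illustrative; with untrained weights the decoder will likely run until `max_decoder_steps`):

    import torch
    from models.tacotron import Tacotron

    model = Tacotron(num_chars=61, linear_dim=1025, mel_dim=80, r=5, memory_size=5)
    model.eval()

    # One hypothetical character/phoneme-ID sequence; the decoder's stopping
    # heuristic assumes batch size 1.
    chars = torch.randint(0, 61, (1, 42)).long()
    with torch.no_grad():
        mel, linear, alignments, stop_tokens = model.inference(chars)
    # mel: (1, frames, 80), linear: (1, frames, 1025), alignments: (1, steps, 42)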

train.py

@@ -390,7 +390,6 @@ def main(args):
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
     model = Tacotron(
         num_chars=num_chars,
-        embedding_dim=c.embedding_size,
         linear_dim=ap.num_freq,
         mel_dim=ap.num_mels,
         r=c.r,

utils/synthesis.py

@@ -20,7 +20,7 @@ def synthesis(m, s, CONFIG, use_cuda, ap):
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda()
-    mel_spec, linear_spec, alignments, stop_tokens = m.forward(
+    mel_spec, linear_spec, alignments, stop_tokens = m.inference(
         chars_var.long())
     linear_spec = linear_spec[0].data.cpu().numpy()
     mel_spec = mel_spec[0].data.cpu().numpy()
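
`forward` now unconditionally reshapes `memory`, so calling `m.forward(chars_var.long())` without ground-truth frames would fail as soon as the decoder tries to reshape a missing memory tensor; synthesis has to go through `inference`. The updated call in context (a sketch: `m` and `use_cuda` are the enclosing function's arguments, and the ID values stand in for whatever the text front end produced):

    import numpy as np
    import torch

    seq = np.asarray([12, 7, 33, 2, 49], dtype=np.int64)  # hypothetical phoneme IDs
    chars_var = torch.from_numpy(seq).unsqueeze(0)         # add a batch dimension of 1
    if use_cuda:
        chars_var = chars_var.cuda()
    # forward() is now the teacher-forced training path; eval uses inference().
    mel_spec, linear_spec, alignments, stop_tokens = m.inference(chars_var.long())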