compute sequence mask in model, add tacotron2 related files

Eren Golge 2019-03-06 13:14:58 +01:00
parent a2a22d253f
commit b031a65677
4 changed files with 508 additions and 7 deletions

layers/tacotron2.py (new file)

@@ -0,0 +1,385 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

class Linear(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)

class Prenet(nn.Module):
    def __init__(self, in_features, out_features=[256, 256]):
        super(Prenet, self).__init__()
        in_features = [in_features] + out_features[:-1]
        self.layers = nn.ModuleList([
            Linear(in_size, out_size, bias=False)
            for (in_size, out_size) in zip(in_features, out_features)
        ])

    def forward(self, x):
        for linear in self.layers:
            # Prenet uses dropout at inference time as well. Otherwise,
            # attention quality degrades at inference.
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x
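
The always-on dropout above is easy to sanity-check in isolation: calling `F.dropout` with an explicit `training=True` keeps sampling fresh masks regardless of the module's train/eval mode, unlike `nn.Dropout`, which becomes the identity under `model.eval()`. A minimal standalone check (not part of this diff):

import torch
from torch.nn import functional as F

x = torch.ones(8)
# training=True forces stochastic behaviour even in eval mode
a = F.dropout(x, p=0.5, training=True)
b = F.dropout(x, p=0.5, training=True)
print(torch.equal(a, b))  # almost surely False: new mask each call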

class ConvBNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
        super(ConvBNBlock, self).__init__()
        assert (kernel_size - 1) % 2 == 0
        padding = (kernel_size - 1) // 2
        conv1d = nn.Conv1d(
            in_channels, out_channels, kernel_size, padding=padding)
        norm = nn.BatchNorm1d(out_channels)
        dropout = nn.Dropout(p=0.5)
        if nonlinear == 'relu':
            self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
        elif nonlinear == 'tanh':
            self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
        else:
            self.net = nn.Sequential(conv1d, norm, dropout)

    def forward(self, x):
        output = self.net(x)
        return output

class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        # use the passed kernel size instead of a hardcoded 31
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=attention_kernel_size,
            stride=1,
            padding=(attention_kernel_size - 1) // 2,
            bias=False)
        self.location_dense = Linear(
            attention_n_filters, attention_dim, bias=False, init_gain='tanh')

    def forward(self, attention_cat):
        processed_attention = self.location_conv(attention_cat)
        processed_attention = self.location_dense(
            processed_attention.transpose(1, 2))
        return processed_attention

class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size,
                 windowing):
        super(Attention, self).__init__()
        self.query_layer = Linear(
            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
        self.inputs_layer = Linear(
            embedding_dim, attention_dim, bias=False, init_gain='tanh')
        self.v = Linear(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self._mask_value = -float("inf")
        self.windowing = windowing
        if self.windowing:
            self.win_back = 1
            self.win_front = 3
            self.win_idx = None

    def init_win_idx(self):
        self.win_idx = 0

    def get_attention(self, query, processed_inputs, attention_cat):
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights +
                       processed_inputs))
        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, inputs, processed_inputs,
                attention_cat, mask):
        attention = self.get_attention(
            attention_hidden_state, processed_inputs, attention_cat)
        if mask is not None:
            # mask is a boolean (B, T) tensor, True on valid steps;
            # fill padded steps with -inf before normalization
            attention.data.masked_fill_(~mask, self._mask_value)
        # Windowing: at inference, restrict attention to a moving window
        # around the most recently attended encoder step.
        if not self.training and self.windowing:
            back_win = self.win_idx - self.win_back
            front_win = self.win_idx + self.win_front
            if back_win > 0:
                attention[:, :back_win] = -float("inf")
            if front_win < inputs.shape[1]:
                attention[:, front_win:] = -float("inf")
            # Update the window position
            self.win_idx = torch.argmax(attention, 1).long()[0].item()
        # Normalize sigmoid outputs instead of taking a softmax;
        # sigmoid(-inf) == 0, so masked steps get exactly zero weight.
        alignment = torch.sigmoid(attention) / torch.sigmoid(
            attention).sum(dim=1).unsqueeze(1)
        context = torch.bmm(alignment.unsqueeze(1), inputs)
        context = context.squeeze(1)
        return context, alignment
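
Because the alignment is a normalized sigmoid rather than a softmax, masking energies with -inf still drives padded positions to exactly zero weight. A small standalone sketch of that masking behaviour, with made-up lengths:

import torch

B, T = 2, 5
lengths = torch.tensor([5, 3])
mask = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)  # (B, T) bool
energies = torch.randn(B, T)
energies = energies.masked_fill(~mask, -float("inf"))
align = torch.sigmoid(energies) / torch.sigmoid(energies).sum(1, keepdim=True)
print(align[1, 3:])  # tensor([0., 0.]) -- padded steps get zero weight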

class Postnet(nn.Module):
    def __init__(self, mel_dim, num_convs=5):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()
        self.convolutions.append(
            ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
        for _ in range(1, num_convs - 1):
            self.convolutions.append(
                ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
        self.convolutions.append(
            ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))

    def forward(self, x):
        for layer in self.convolutions:
            x = layer(x)
        return x

class Encoder(nn.Module):
    def __init__(self, in_features=512):
        super(Encoder, self).__init__()
        convolutions = []
        for _ in range(3):
            convolutions.append(
                ConvBNBlock(in_features, in_features, 5, 'relu'))
        self.convolutions = nn.Sequential(*convolutions)
        self.lstm = nn.LSTM(
            in_features,
            int(in_features / 2),
            num_layers=1,
            batch_first=True,
            bidirectional=True)

    def forward(self, x, input_lengths):
        x = self.convolutions(x)
        x = x.transpose(1, 2)
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)
        return outputs

    def inference(self, x):
        x = self.convolutions(x)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs
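
Since `forward` packs the padded batch, `input_lengths` must be sorted in descending order for `pack_padded_sequence` under this PyTorch version. A quick shape check with hypothetical sizes:

import torch
from layers.tacotron2 import Encoder

encoder = Encoder(in_features=512)
x = torch.randn(2, 512, 57)        # (B, in_features, T_in)
lengths = torch.tensor([57, 50])   # sorted, longest first
out = encoder(x, lengths)
print(out.shape)                   # torch.Size([2, 57, 512])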

# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
    def __init__(self, in_features, inputs_dim, r, attn_win):
        super(Decoder, self).__init__()
        self.mel_channels = inputs_dim
        self.r = r
        self.encoder_embedding_dim = in_features
        self.attention_rnn_dim = 1024
        self.decoder_rnn_dim = 1024
        self.prenet_dim = 256
        self.max_decoder_steps = 1000
        self.gate_threshold = 0.5
        self.p_attention_dropout = 0.1
        self.p_decoder_dropout = 0.1
        self.prenet = Prenet(self.mel_channels * r,
                             [self.prenet_dim, self.prenet_dim])
        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                         self.attention_rnn_dim)
        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
                                         128, 32, 31, attn_win)
        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                       self.decoder_rnn_dim, bias=True)
        self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
                                        self.mel_channels * r)
        self.stopnet = nn.Sequential(
            nn.Dropout(0.1),
            Linear(self.decoder_rnn_dim + self.mel_channels * r,
                   1,
                   bias=True,
                   init_gain='sigmoid'))
        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
        self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
        self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)
    def get_go_frame(self, inputs):
        B = inputs.size(0)
        memory = self.go_frame_init(inputs.data.new_zeros(B).long())
        return memory

    def _init_states(self, inputs, mask):
        B = inputs.size(0)
        T = inputs.size(1)
        self.attention_hidden = self.attention_rnn_init(
            inputs.data.new_zeros(B).long())
        self.attention_cell = Variable(
            inputs.data.new(B, self.attention_rnn_dim).zero_())
        self.decoder_hidden = self.decoder_rnn_inits(
            inputs.data.new_zeros(B).long())
        self.decoder_cell = Variable(
            inputs.data.new(B, self.decoder_rnn_dim).zero_())
        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
        self.context = Variable(
            inputs.data.new(B, self.encoder_embedding_dim).zero_())
        self.inputs = inputs
        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
        self.mask = mask

    def _reshape_memory(self, memories):
        memories = memories.view(
            memories.size(0), int(memories.size(1) / self.r), -1)
        memories = memories.transpose(0, 1)
        return memories

    def _parse_outputs(self, outputs, gate_outputs, alignments):
        alignments = torch.stack(alignments).transpose(0, 1)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
        outputs = outputs.transpose(1, 2)
        return outputs, gate_outputs, alignments
    def decode(self, memory):
        cell_input = torch.cat((memory, self.context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)
        self.attention_cell = F.dropout(
            self.attention_cell, self.p_attention_dropout, self.training)
        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        self.context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.inputs, self.processed_inputs,
            attention_cat, self.mask)
        self.attention_weights_cum += self.attention_weights
        memory = torch.cat((self.attention_hidden, self.context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            memory, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(self.decoder_hidden,
                                        self.p_decoder_dropout, self.training)
        self.decoder_cell = F.dropout(self.decoder_cell,
                                      self.p_decoder_dropout, self.training)
        decoder_hidden_context = torch.cat(
            (self.decoder_hidden, self.context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_context)
        stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
        gate_prediction = self.stopnet(stopnet_input)
        return decoder_output, gate_prediction, self.attention_weights
    def forward(self, inputs, memories, mask):
        memory = self.get_go_frame(inputs).unsqueeze(0)
        memories = self._reshape_memory(memories)
        memories = torch.cat((memory, memories), dim=0)
        memories = self.prenet(memories)
        self._init_states(inputs, mask=mask)
        outputs, gate_outputs, alignments = [], [], []
        while len(outputs) < memories.size(0) - 1:
            memory = memories[len(outputs)]
            mel_output, gate_output, attention_weights = self.decode(memory)
            outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]
        outputs, gate_outputs, alignments = self._parse_outputs(
            outputs, gate_outputs, alignments)
        return outputs, gate_outputs, alignments
    def inference(self, inputs):
        memory = self.get_go_frame(inputs)
        self._init_states(inputs, mask=None)
        self.attention_layer.init_win_idx()
        outputs, gate_outputs, alignments, t = [], [], [], 0
        stop_flags = [False, False]
        while True:
            memory = self.prenet(memory)
            mel_output, gate_output, alignment = self.decode(memory)
            gate_output = torch.sigmoid(gate_output.data)
            outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]
            # stop only when both the stopnet fires and attention has
            # reached the last encoder steps
            stop_flags[0] = stop_flags[0] or gate_output > self.gate_threshold
            stop_flags[1] = stop_flags[1] or alignment[0, -3:].sum() > 0.5
            if all(stop_flags):
                break
            elif len(outputs) == self.max_decoder_steps:
                print(" | > Decoder stopped with 'max_decoder_steps'")
                break
            memory = mel_output
            t += 1
        outputs, gate_outputs, alignments = self._parse_outputs(
            outputs, gate_outputs, alignments)
        return outputs, gate_outputs, alignments
    def inference_step(self, inputs, t, memory=None):
        """For debugging purposes."""
        if t == 0:
            memory = self.get_go_frame(inputs)
            self._init_states(inputs, mask=None)
        memory = self.prenet(memory)
        mel_output, gate_output, alignment = self.decode(memory)
        gate_output = torch.sigmoid(gate_output.data)
        memory = mel_output
        return mel_output, gate_output, alignment

models/tacotron.py

@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from math import sqrt
 from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
+from utils.generic_utils import sequence_mask

 class Tacotron(nn.Module):
@@ -27,15 +28,13 @@ class Tacotron(nn.Module):
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
             nn.Sigmoid())

-    def forward(self, characters, mel_specs=None, mask=None):
+    def forward(self, characters, text_lengths, mel_specs=None):
         B = characters.size(0)
+        mask = sequence_mask(text_lengths).to(characters.device)
         inputs = self.embedding(characters)
-        # batch x time x dim
         encoder_outputs = self.encoder(inputs)
-        # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
-        # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
@@ -44,12 +43,9 @@ class Tacotron(nn.Module):
     def inference(self, characters):
         B = characters.size(0)
         inputs = self.embedding(characters)
-        # batch x time x dim
         encoder_outputs = self.encoder(inputs)
-        # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
-        # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
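
The `sequence_mask` helper itself lives in utils/generic_utils.py and is not part of this diff; it presumably builds a (B, T) boolean mask that is True on valid (non-padded) positions, roughly like this sketch:

import torch

def sequence_mask(sequence_length, max_len=None):
    """True for valid positions, False for padding (rough sketch)."""
    if max_len is None:
        max_len = sequence_length.max().item()
    seq_range = torch.arange(max_len, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)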

models/tacotron2.py (new file)

@@ -0,0 +1,51 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers.tacotron2 import Encoder, Decoder, Postnet
from utils.generic_utils import sequence_mask


# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
    def __init__(self, num_chars, r, attn_win=False):
        super(Tacotron2, self).__init__()
        self.n_mel_channels = 80
        self.n_frames_per_step = r
        self.embedding = nn.Embedding(num_chars, 512)
        std = sqrt(2.0 / (num_chars + 512))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(512)
        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
        self.postnet = Postnet(self.n_mel_channels)

    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
        mel_outputs = mel_outputs.transpose(1, 2)
        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
        return mel_outputs, mel_outputs_postnet, alignments
    def forward(self, text, text_lengths, mel_specs=None):
        # compute mask for padding (the argument here is `text`, not
        # `characters` as in the Tacotron model)
        mask = sequence_mask(text_lengths).to(text.device)
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
        mel_outputs, stop_tokens, alignments = self.decoder(
            encoder_outputs, mel_specs, mask)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
            mel_outputs, mel_outputs_postnet, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
    def inference(self, text):
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, stop_tokens, alignments = self.decoder.inference(
            encoder_outputs)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
            mel_outputs, mel_outputs_postnet, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
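
A hypothetical smoke run of the new model, shapes only (assumes the repo's `sequence_mask` import resolves; all sizes are illustrative):

import torch
from models.tacotron2 import Tacotron2

model = Tacotron2(num_chars=24, r=1)
text = torch.randint(0, 24, (2, 57)).long()
text_lengths = torch.tensor([57, 50])   # sorted for the encoder's packing
mel_specs = torch.rand(2, 30, 80)       # (B, T_mel, n_mel_channels)
mel, mel_postnet, align, stops = model(text, text_lengths, mel_specs)
print(mel.shape, mel_postnet.shape)     # torch.Size([2, 30, 80]) each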

tests/tacotron2_tests.py (new file)

@@ -0,0 +1,69 @@
import os
import copy
import torch
import unittest
import numpy as np
from torch import optim
from torch import nn
from utils.generic_utils import load_config
from layers.losses import MSELossMasked
from models.tacotron2 import Tacotron2

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))

class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):
        input = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

        # set stop targets per sample, from each sequence's end onwards
        for idx, length in enumerate(mel_lengths):
            stop_targets[idx, int(length.item()):, 0] = 1.0

        stop_targets = stop_targets.view(input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) >
                        0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked().to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(24, c.r).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input, input_lengths, mel_spec)
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = loss + criterion(mel_postnet_out, mel_postnet_spec,
                                    mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check that every parameter was updated by training
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            # ignore the pre-highway layer since it is applied conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(
            ), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref)
            count += 1
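
Assuming the repo layout above, this test can presumably be run from the project root with `python -m unittest tests.tacotron2_tests`; it expects `tests/test_config.json` to supply the `audio`, `r`, and `lr` settings referenced through `c`.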