Set tacotron model parameters to adap to common_layers.py - Prenet and Attention

This commit is contained in:
Eren Golge 2019-05-27 14:40:28 +02:00
parent 2586be7d33
commit ba492f43be
5 changed files with 158 additions and 46 deletions

83
layers/common_layers.py Normal file
View File

@ -0,0 +1,83 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
class Linear(nn.Module):
def __init__(self,
in_features,
out_features,
bias=True,
init_gain='linear'):
super(Linear, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)
self._init_w(init_gain)
def _init_w(self, init_gain):
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(init_gain))
def forward(self, x):
return self.linear_layer(x)
class LinearBN(nn.Module):
def __init__(self,
in_features,
out_features,
bias=True,
init_gain='linear'):
super(LinearBN, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)
self.bn = nn.BatchNorm1d(out_features)
self._init_w(init_gain)
def _init_w(self, init_gain):
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(init_gain))
def forward(self, x):
out = self.linear_layer(x)
if len(out.shape) == 3:
out = out.permute(1, 2, 0)
out = self.bn(out)
if len(out.shape) == 3:
out = out.permute(2, 0, 1)
return out
class Prenet(nn.Module):
def __init__(self,
in_features,
prenet_type="original",
prenet_dropout=True,
out_features=[256, 256],
bias=True):
super(Prenet, self).__init__()
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
in_features = [in_features] + out_features[:-1]
if prenet_type == "bn":
self.layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
elif prenet_type == "original":
self.layers = nn.ModuleList([
Linear(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
def forward(self, x):
for linear in self.layers:
if self.prenet_dropout:
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
else:
x = F.relu(linear(x))
return x

View File

@ -2,38 +2,39 @@
import torch import torch
from torch import nn from torch import nn
from .attention import AttentionRNNCell from .attention import AttentionRNNCell
from .common_layers import Prenet
class Prenet(nn.Module): # class Prenet(nn.Module):
r""" Prenet as explained at https://arxiv.org/abs/1703.10135. # r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
It creates as many layers as given by 'out_features' # It creates as many layers as given by 'out_features'
Args: # Args:
in_features (int): size of the input vector # in_features (int): size of the input vector
out_features (int or list): size of each output sample. # out_features (int or list): size of each output sample.
If it is a list, for each value, there is created a new layer. # If it is a list, for each value, there is created a new layer.
""" # """
def __init__(self, in_features, out_features=[256, 128]): # def __init__(self, in_features, out_features=[256, 128]):
super(Prenet, self).__init__() # super(Prenet, self).__init__()
in_features = [in_features] + out_features[:-1] # in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList([ # self.layers = nn.ModuleList([
nn.Linear(in_size, out_size) # nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_features, out_features) # for (in_size, out_size) in zip(in_features, out_features)
]) # ])
self.relu = nn.ReLU() # self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5) # self.dropout = nn.Dropout(0.5)
# self.init_layers() # # self.init_layers()
def init_layers(self): # def init_layers(self):
for layer in self.layers: # for layer in self.layers:
torch.nn.init.xavier_uniform_( # torch.nn.init.xavier_uniform_(
layer.weight, gain=torch.nn.init.calculate_gain('relu')) # layer.weight, gain=torch.nn.init.calculate_gain('relu'))
def forward(self, inputs): # def forward(self, inputs):
for linear in self.layers: # for linear in self.layers:
inputs = self.dropout(self.relu(linear(inputs))) # inputs = self.dropout(self.relu(linear(inputs)))
return inputs # return inputs
class BatchNormConv1d(nn.Module): class BatchNormConv1d(nn.Module):
@ -301,8 +302,9 @@ class Decoder(nn.Module):
memory_size (int): size of the past window. if <= 0 memory_size = r memory_size (int): size of the past window. if <= 0 memory_size = r
""" """
def __init__(self, in_features, memory_dim, r, memory_size, def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
attn_windowing, attn_norm, separate_stopnet): attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, location_attn, separate_stopnet):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.r = r self.r = r
self.in_features = in_features self.in_features = in_features
@ -312,7 +314,10 @@ class Decoder(nn.Module):
self.separate_stopnet = separate_stopnet self.separate_stopnet = separate_stopnet
# memory -> |Prenet| -> processed_memory # memory -> |Prenet| -> processed_memory
self.prenet = Prenet( self.prenet = Prenet(
memory_dim * self.memory_size, out_features=[256, 128]) memory_dim * self.memory_size,
prenet_type,
prenet_dropout,
out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = AttentionRNNCell( self.attention_rnn = AttentionRNNCell(
out_dim=128, out_dim=128,
@ -385,23 +390,22 @@ class Decoder(nn.Module):
stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
return outputs, attentions, stop_tokens return outputs, attentions, stop_tokens
def decode(self, def decode(self, inputs, t, mask=None):
inputs,
t,
mask=None):
# Prenet # Prenet
processed_memory = self.prenet(self.memory_input) processed_memory = self.prenet(self.memory_input)
# Attention RNN # Attention RNN
attention_cat = torch.cat( attention_cat = torch.cat(
(self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), dim=1) (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)),
dim=1)
self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn( self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
processed_memory, self.current_context_vec, self.attention_rnn_hidden, processed_memory, self.current_context_vec,
inputs, attention_cat, mask, t) self.attention_rnn_hidden, inputs, attention_cat, mask, t)
del attention_cat del attention_cat
self.attention_cum += self.attention self.attention_cum += self.attention
# Concat RNN output and attention context vector # Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in( decoder_input = self.project_to_decoder_in(
torch.cat((self.attention_rnn_hidden, self.current_context_vec), -1)) torch.cat((self.attention_rnn_hidden, self.current_context_vec),
-1))
# Pass through the decoder RNNs # Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)): for idx in range(len(self.decoder_rnns)):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx]( self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@ -427,7 +431,8 @@ class Decoder(nn.Module):
self.memory_input = torch.cat([ self.memory_input = torch.cat([
self.memory_input[:, self.r * self.memory_dim:].clone(), self.memory_input[:, self.r * self.memory_dim:].clone(),
new_memory new_memory
], dim=-1) ],
dim=-1)
else: else:
self.memory_input = new_memory self.memory_input = new_memory

View File

@ -9,23 +9,29 @@ from utils.generic_utils import sequence_mask
class Tacotron(nn.Module): class Tacotron(nn.Module):
def __init__(self, def __init__(self,
num_chars, num_chars,
r=5,
linear_dim=1025, linear_dim=1025,
mel_dim=80, mel_dim=80,
r=5,
padding_idx=None,
memory_size=5, memory_size=5,
attn_win=False, attn_win=False,
attn_norm="sigmoid", attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
location_attn=True,
separate_stopnet=True): separate_stopnet=True):
super(Tacotron, self).__init__() super(Tacotron, self).__init__()
self.r = r self.r = r
self.mel_dim = mel_dim self.mel_dim = mel_dim
self.linear_dim = linear_dim self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx) self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3) self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256) self.encoder = Encoder(256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
attn_norm, separate_stopnet) attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, location_attn,
separate_stopnet)
self.postnet = PostCBHG(mel_dim) self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Sequential( self.last_linear = nn.Sequential(
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim), nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),

View File

@ -18,6 +18,11 @@ file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json')) c = load_config(os.path.join(file_path, 'test_config.json'))
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TacotronTrainTest(unittest.TestCase): class TacotronTrainTest(unittest.TestCase):
def test_train_step(self): def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device) input = torch.randint(0, 24, (8, 128)).long().to(device)
@ -33,13 +38,19 @@ class TacotronTrainTest(unittest.TestCase):
stop_targets = stop_targets.view(input.shape[0], stop_targets = stop_targets.view(input.shape[0],
stop_targets.size(1) // c.r, -1) stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze()
criterion = L1LossMasked().to(device) criterion = L1LossMasked().to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device) criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron(32, c.audio['num_freq'], c.audio['num_mels'], model = Tacotron(
c.r, memory_size=c.memory_size).to(device) 32,
linear_dim=c.audio['num_freq'],
mel_dim=c.audio['num_mels'],
r=c.r,
memory_size=c.memory_size).to(device)
model.train() model.train()
print(" > Num parameters for Tacotron model:%s"%(count_parameters(model)))
model_ref = copy.deepcopy(model) model_ref = copy.deepcopy(model)
count = 0 count = 0
for param, param_ref in zip(model.parameters(), for param, param_ref in zip(model.parameters(),

View File

@ -251,9 +251,16 @@ def setup_model(num_chars, c):
model = MyModel( model = MyModel(
num_chars=num_chars, num_chars=num_chars,
r=c.r, r=c.r,
linear_dim=1025,
mel_dim=80,
memory_size=c.memory_size,
attn_win=c.windowing, attn_win=c.windowing,
attn_norm=c.attention_norm, attn_norm=c.attention_norm,
memory_size=c.memory_size, prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet) separate_stopnet=c.separate_stopnet)
elif c.model.lower() == "tacotron2": elif c.model.lower() == "tacotron2":
model = MyModel( model = MyModel(