renaming layers to be converted to TF counterpart

This commit is contained in:
erogol 2020-04-28 18:16:02 +02:00
parent bee288fa93
commit d282222553
3 changed files with 80 additions and 84 deletions

View File

@ -33,7 +33,7 @@ class LinearBN(nn.Module):
super(LinearBN, self).__init__() super(LinearBN, self).__init__()
self.linear_layer = torch.nn.Linear( self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias) in_features, out_features, bias=bias)
self.bn = nn.BatchNorm1d(out_features) self.batch_normalization = nn.BatchNorm1d(out_features)
self._init_w(init_gain) self._init_w(init_gain)
def _init_w(self, init_gain): def _init_w(self, init_gain):
@ -45,7 +45,7 @@ class LinearBN(nn.Module):
out = self.linear_layer(x) out = self.linear_layer(x)
if len(out.shape) == 3: if len(out.shape) == 3:
out = out.permute(1, 2, 0) out = out.permute(1, 2, 0)
out = self.bn(out) out = self.batch_normalization(out)
if len(out.shape) == 3: if len(out.shape) == 3:
out = out.permute(2, 0, 1) out = out.permute(2, 0, 1)
return out return out
@ -63,18 +63,18 @@ class Prenet(nn.Module):
self.prenet_dropout = prenet_dropout self.prenet_dropout = prenet_dropout
in_features = [in_features] + out_features[:-1] in_features = [in_features] + out_features[:-1]
if prenet_type == "bn": if prenet_type == "bn":
self.layers = nn.ModuleList([ self.linear_layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=bias) LinearBN(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features) for (in_size, out_size) in zip(in_features, out_features)
]) ])
elif prenet_type == "original": elif prenet_type == "original":
self.layers = nn.ModuleList([ self.linear_layers = nn.ModuleList([
Linear(in_size, out_size, bias=bias) Linear(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features) for (in_size, out_size) in zip(in_features, out_features)
]) ])
def forward(self, x): def forward(self, x):
for linear in self.layers: for linear in self.linear_layers:
if self.prenet_dropout: if self.prenet_dropout:
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training) x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
else: else:
@ -93,7 +93,7 @@ class LocationLayer(nn.Module):
attention_n_filters=32, attention_n_filters=32,
attention_kernel_size=31): attention_kernel_size=31):
super(LocationLayer, self).__init__() super(LocationLayer, self).__init__()
self.location_conv = nn.Conv1d( self.location_conv1d = nn.Conv1d(
in_channels=2, in_channels=2,
out_channels=attention_n_filters, out_channels=attention_n_filters,
kernel_size=attention_kernel_size, kernel_size=attention_kernel_size,
@ -104,7 +104,7 @@ class LocationLayer(nn.Module):
attention_n_filters, attention_dim, bias=False, init_gain='tanh') attention_n_filters, attention_dim, bias=False, init_gain='tanh')
def forward(self, attention_cat): def forward(self, attention_cat):
processed_attention = self.location_conv(attention_cat) processed_attention = self.location_conv1d(attention_cat)
processed_attention = self.location_dense( processed_attention = self.location_dense(
processed_attention.transpose(1, 2)) processed_attention.transpose(1, 2))
return processed_attention return processed_attention

View File

@ -6,130 +6,126 @@ from .common_layers import init_attn, Prenet, Linear
class ConvBNBlock(nn.Module): class ConvBNBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None): def __init__(self, in_channels, out_channels, kernel_size, activation=None):
super(ConvBNBlock, self).__init__() super(ConvBNBlock, self).__init__()
assert (kernel_size - 1) % 2 == 0 assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
conv1d = nn.Conv1d(in_channels, self.convolution1d = nn.Conv1d(in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
padding=padding) padding=padding)
norm = nn.BatchNorm1d(out_channels) self.batch_normalization = nn.BatchNorm1d(out_channels)
dropout = nn.Dropout(p=0.5) self.dropout = nn.Dropout(p=0.5)
if nonlinear == 'relu': if activation == 'relu':
self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout) self.activation = nn.ReLU()
elif nonlinear == 'tanh': elif activation == 'tanh':
self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout) self.activation = nn.Tanh()
else: else:
self.net = nn.Sequential(conv1d, norm, dropout) self.activation = nn.Identity()
def forward(self, x): def forward(self, x):
output = self.net(x) o = self.convolution1d(x)
return output o = self.batch_normalization(o)
o = self.activation(o)
o = self.dropout(o)
return o
class Postnet(nn.Module): class Postnet(nn.Module):
def __init__(self, mel_dim, num_convs=5): def __init__(self, output_dim, num_convs=5):
super(Postnet, self).__init__() super(Postnet, self).__init__()
self.convolutions = nn.ModuleList() self.convolutions = nn.ModuleList()
self.convolutions.append( self.convolutions.append(
ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh')) ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
for _ in range(1, num_convs - 1): for _ in range(1, num_convs - 1):
self.convolutions.append( self.convolutions.append(
ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh')) ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
self.convolutions.append( self.convolutions.append(
ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None)) ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
def forward(self, x): def forward(self, x):
o = x
for layer in self.convolutions: for layer in self.convolutions:
x = layer(x) o = layer(o)
return x return o
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, in_features=512): def __init__(self, output_input_dim=512):
super(Encoder, self).__init__() super(Encoder, self).__init__()
convolutions = [] self.convolutions = nn.ModuleList()
for _ in range(3): for _ in range(3):
convolutions.append( self.convolutions.append(
ConvBNBlock(in_features, in_features, 5, 'relu')) ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
self.convolutions = nn.Sequential(*convolutions) self.lstm = nn.LSTM(output_input_dim,
self.lstm = nn.LSTM(in_features, int(output_input_dim / 2),
int(in_features / 2),
num_layers=1, num_layers=1,
batch_first=True, batch_first=True,
bidirectional=True) bidirectional=True)
self.rnn_state = None self.rnn_state = None
def forward(self, x, input_lengths): def forward(self, x, input_lengths):
x = self.convolutions(x) o = x
x = x.transpose(1, 2) for layer in self.convolutions:
x = nn.utils.rnn.pack_padded_sequence(x, o = layer(o)
o = o.transpose(1, 2)
o = nn.utils.rnn.pack_padded_sequence(o,
input_lengths, input_lengths,
batch_first=True) batch_first=True)
self.lstm.flatten_parameters() self.lstm.flatten_parameters()
outputs, _ = self.lstm(x) o, _ = self.lstm(o)
outputs, _ = nn.utils.rnn.pad_packed_sequence( o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)
outputs, return o
batch_first=True,
)
return outputs
def inference(self, x): def inference(self, x):
x = self.convolutions(x) o = x
x = x.transpose(1, 2) for layer in self.convolutions:
o = layer(o)
o = x.transpose(1, 2)
self.lstm.flatten_parameters() self.lstm.flatten_parameters()
outputs, _ = self.lstm(x) o, _ = self.lstm(o)
return outputs return o
def inference_truncated(self, x):
"""
Preserve encoder state for continuous inference
"""
x = self.convolutions(x)
x = x.transpose(1, 2)
self.lstm.flatten_parameters()
outputs, self.rnn_state = self.lstm(x, self.rnn_state)
return outputs
# adapted from https://github.com/NVIDIA/tacotron2/ # adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module): class Decoder(nn.Module):
# Pylint gets confused by PyTorch conventions here # Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init #pylint: disable=attribute-defined-outside-init
def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm, def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent, prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, attn_K, separate_stopnet, forward_attn_mask, location_attn, attn_K, separate_stopnet,
speaker_embedding_dim): speaker_embedding_dim):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.memory_dim = memory_dim self.frame_dim = frame_dim
self.r_init = r self.r_init = r
self.r = r self.r = r
self.encoder_embedding_dim = in_features self.encoder_embedding_dim = input_dim
self.separate_stopnet = separate_stopnet self.separate_stopnet = separate_stopnet
self.max_decoder_steps = 1000
self.gate_threshold = 0.5
# model dimensions
self.query_dim = 1024 self.query_dim = 1024
self.decoder_rnn_dim = 1024 self.decoder_rnn_dim = 1024
self.prenet_dim = 256 self.prenet_dim = 256
self.max_decoder_steps = 1000 self.attn_dim = 128
self.gate_threshold = 0.5
self.p_attention_dropout = 0.1 self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1 self.p_decoder_dropout = 0.1
# memory -> |Prenet| -> processed_memory # memory -> |Prenet| -> processed_memory
prenet_dim = self.memory_dim prenet_dim = self.frame_dim
self.prenet = Prenet( self.prenet = Prenet(prenet_dim,
prenet_dim, prenet_type,
prenet_type, prenet_dropout,
prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim],
out_features=[self.prenet_dim, self.prenet_dim], bias=False)
bias=False)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features, self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
self.query_dim) self.query_dim)
self.attention = init_attn(attn_type=attn_type, self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim, query_dim=self.query_dim,
embedding_dim=in_features, embedding_dim=input_dim,
attention_dim=128, attention_dim=128,
location_attention=location_attn, location_attention=location_attn,
attention_location_n_filters=32, attention_location_n_filters=32,
@ -141,15 +137,15 @@ class Decoder(nn.Module):
forward_attn_mask=forward_attn_mask, forward_attn_mask=forward_attn_mask,
attn_K=attn_K) attn_K=attn_K)
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features, self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
self.decoder_rnn_dim, 1) self.decoder_rnn_dim, 1)
self.linear_projection = Linear(self.decoder_rnn_dim + in_features, self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
self.memory_dim * self.r_init) self.frame_dim * self.r_init)
self.stopnet = nn.Sequential( self.stopnet = nn.Sequential(
nn.Dropout(0.1), nn.Dropout(0.1),
Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init, Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
1, 1,
bias=True, bias=True,
init_gain='sigmoid')) init_gain='sigmoid'))
@ -161,7 +157,7 @@ class Decoder(nn.Module):
def get_go_frame(self, inputs): def get_go_frame(self, inputs):
B = inputs.size(0) B = inputs.size(0)
memory = torch.zeros(1, device=inputs.device).repeat(B, memory = torch.zeros(1, device=inputs.device).repeat(B,
self.memory_dim * self.r) self.frame_dim * self.r)
return memory return memory
def _init_states(self, inputs, mask, keep_states=False): def _init_states(self, inputs, mask, keep_states=False):
@ -187,9 +183,9 @@ class Decoder(nn.Module):
Reshape the spectrograms for given 'r' Reshape the spectrograms for given 'r'
""" """
# Grouping multiple frames if necessary # Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim: if memory.size(-1) == self.frame_dim:
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
# Time first (T_decoder, B, memory_dim) # Time first (T_decoder, B, frame_dim)
memory = memory.transpose(0, 1) memory = memory.transpose(0, 1)
return memory return memory
@ -197,22 +193,22 @@ class Decoder(nn.Module):
alignments = torch.stack(alignments).transpose(0, 1) alignments = torch.stack(alignments).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous() outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(outputs.size(0), -1, self.memory_dim) outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
outputs = outputs.transpose(1, 2) outputs = outputs.transpose(1, 2)
return outputs, stop_tokens, alignments return outputs, stop_tokens, alignments
def _update_memory(self, memory): def _update_memory(self, memory):
if len(memory.shape) == 2: if len(memory.shape) == 2:
return memory[:, self.memory_dim * (self.r - 1):] return memory[:, self.frame_dim * (self.r - 1):]
return memory[:, :, self.memory_dim * (self.r - 1):] return memory[:, :, self.frame_dim * (self.r - 1):]
def decode(self, memory): def decode(self, memory):
''' '''
shapes: shapes:
- memory: B x r * self.memory_dim - memory: B x r * self.frame_dim
''' '''
# self.context: B x D_en # self.context: B x D_en
# query_input: B x D_en + (r * self.memory_dim) # query_input: B x D_en + (r * self.frame_dim)
query_input = torch.cat((memory, self.context), -1) query_input = torch.cat((memory, self.context), -1)
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn # self.query and self.attention_rnn_cell_state : B x D_attn_rnn
self.query, self.attention_rnn_cell_state = self.attention_rnn( self.query, self.attention_rnn_cell_state = self.attention_rnn(
@ -235,16 +231,16 @@ class Decoder(nn.Module):
# B x (D_decoder_rnn + D_en) # B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
dim=1) dim=1)
# B x (self.r * self.memory_dim) # B x (self.r * self.frame_dim)
decoder_output = self.linear_projection(decoder_hidden_context) decoder_output = self.linear_projection(decoder_hidden_context)
# B x (D_decoder_rnn + (self.r * self.memory_dim)) # B x (D_decoder_rnn + (self.r * self.frame_dim))
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1) stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
if self.separate_stopnet: if self.separate_stopnet:
stop_token = self.stopnet(stopnet_input.detach()) stop_token = self.stopnet(stopnet_input.detach())
else: else:
stop_token = self.stopnet(stopnet_input) stop_token = self.stopnet(stopnet_input)
# select outputs for the reduction rate self.r # select outputs for the reduction rate self.r
decoder_output = decoder_output[:, :self.r * self.memory_dim] decoder_output = decoder_output[:, :self.r * self.frame_dim]
return decoder_output, self.attention.attention_weights, stop_token return decoder_output, self.attention.attention_weights, stop_token
def forward(self, inputs, memories, mask, speaker_embeddings=None): def forward(self, inputs, memories, mask, speaker_embeddings=None):

View File

@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
super(Tacotron2, self).__init__() super(Tacotron2, self).__init__()
self.postnet_output_dim = postnet_output_dim self.postnet_output_dim = postnet_output_dim
self.decoder_output_dim = decoder_output_dim self.decoder_output_dim = decoder_output_dim
self.n_frames_per_step = r self.r = r
self.bidirectional_decoder = bidirectional_decoder self.bidirectional_decoder = bidirectional_decoder
decoder_dim = 512 if num_speakers > 1 else 512 decoder_dim = 512 if num_speakers > 1 else 512
encoder_dim = 512 if num_speakers > 1 else 512 encoder_dim = 512 if num_speakers > 1 else 512