From d282222553c97821f1028495fb7161d2f60b491d Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 28 Apr 2020 18:16:02 +0200
Subject: [PATCH] renaming layers to be converted to TF counterpart

---
 layers/common_layers.py |  14 ++--
 layers/tacotron2.py     | 148 +++++++++++++++++++---------------------
 models/tacotron2.py     |   2 +-
 3 files changed, 80 insertions(+), 84 deletions(-)

diff --git a/layers/common_layers.py b/layers/common_layers.py
index 8b7ed125..d2afe012 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -33,7 +33,7 @@ class LinearBN(nn.Module):
         super(LinearBN, self).__init__()
         self.linear_layer = torch.nn.Linear(
             in_features, out_features, bias=bias)
-        self.bn = nn.BatchNorm1d(out_features)
+        self.batch_normalization = nn.BatchNorm1d(out_features)
         self._init_w(init_gain)

     def _init_w(self, init_gain):
@@ -45,7 +45,7 @@ class LinearBN(nn.Module):
         out = self.linear_layer(x)
         if len(out.shape) == 3:
             out = out.permute(1, 2, 0)
-        out = self.bn(out)
+        out = self.batch_normalization(out)
         if len(out.shape) == 3:
             out = out.permute(2, 0, 1)
         return out
@@ -63,18 +63,18 @@ class Prenet(nn.Module):
         self.prenet_dropout = prenet_dropout
         in_features = [in_features] + out_features[:-1]
         if prenet_type == "bn":
-            self.layers = nn.ModuleList([
+            self.linear_layers = nn.ModuleList([
                 LinearBN(in_size, out_size, bias=bias)
                 for (in_size, out_size) in zip(in_features, out_features)
             ])
         elif prenet_type == "original":
-            self.layers = nn.ModuleList([
+            self.linear_layers = nn.ModuleList([
                 Linear(in_size, out_size, bias=bias)
                 for (in_size, out_size) in zip(in_features, out_features)
             ])

     def forward(self, x):
-        for linear in self.layers:
+        for linear in self.linear_layers:
             if self.prenet_dropout:
                 x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
             else:
                 x = F.relu(linear(x))
         return x
@@ -93,7 +93,7 @@ class LocationLayer(nn.Module):
                  attention_n_filters=32,
                  attention_kernel_size=31):
         super(LocationLayer, self).__init__()
-        self.location_conv = nn.Conv1d(
+        self.location_conv1d = nn.Conv1d(
             in_channels=2,
             out_channels=attention_n_filters,
             kernel_size=attention_kernel_size,
@@ -104,7 +104,7 @@
             attention_n_filters, attention_dim, bias=False, init_gain='tanh')

     def forward(self, attention_cat):
-        processed_attention = self.location_conv(attention_cat)
+        processed_attention = self.location_conv1d(attention_cat)
         processed_attention = self.location_dense(
             processed_attention.transpose(1, 2))
         return processed_attention
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index fa76a6b2..3e439b9b 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -6,130 +6,126 @@ from .common_layers import init_attn, Prenet, Linear


 class ConvBNBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
+    def __init__(self, in_channels, out_channels, kernel_size, activation=None):
         super(ConvBNBlock, self).__init__()
         assert (kernel_size - 1) % 2 == 0
         padding = (kernel_size - 1) // 2
-        conv1d = nn.Conv1d(in_channels,
+        self.convolution1d = nn.Conv1d(in_channels,
                            out_channels,
                            kernel_size,
                            padding=padding)
-        norm = nn.BatchNorm1d(out_channels)
-        dropout = nn.Dropout(p=0.5)
-        if nonlinear == 'relu':
-            self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
-        elif nonlinear == 'tanh':
-            self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
+        self.batch_normalization = nn.BatchNorm1d(out_channels)
+        self.dropout = nn.Dropout(p=0.5)
+        if activation == 'relu':
+            self.activation = nn.ReLU()
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
         else:
-            self.net = nn.Sequential(conv1d, norm, dropout)
+            self.activation = nn.Identity()

     def forward(self, x):
-        output = self.net(x)
-        return output
+        o = self.convolution1d(x)
+        o = self.batch_normalization(o)
+        o = self.activation(o)
+        o = self.dropout(o)
+        return o


 class Postnet(nn.Module):
-    def __init__(self, mel_dim, num_convs=5):
+    def __init__(self, output_dim, num_convs=5):
         super(Postnet, self).__init__()
         self.convolutions = nn.ModuleList()
         self.convolutions.append(
-            ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
+            ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
         for _ in range(1, num_convs - 1):
             self.convolutions.append(
-                ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
+                ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
         self.convolutions.append(
-            ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))
+            ConvBNBlock(512, output_dim, kernel_size=5, activation=None))

     def forward(self, x):
+        o = x
         for layer in self.convolutions:
-            x = layer(x)
-        return x
+            o = layer(o)
+        return o


 class Encoder(nn.Module):
-    def __init__(self, in_features=512):
+    def __init__(self, output_input_dim=512):
         super(Encoder, self).__init__()
-        convolutions = []
+        self.convolutions = nn.ModuleList()
         for _ in range(3):
-            convolutions.append(
-                ConvBNBlock(in_features, in_features, 5, 'relu'))
-        self.convolutions = nn.Sequential(*convolutions)
-        self.lstm = nn.LSTM(in_features,
-                            int(in_features / 2),
+            self.convolutions.append(
+                ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
+        self.lstm = nn.LSTM(output_input_dim,
+                            int(output_input_dim / 2),
                             num_layers=1,
                             batch_first=True,
                             bidirectional=True)
         self.rnn_state = None

     def forward(self, x, input_lengths):
-        x = self.convolutions(x)
-        x = x.transpose(1, 2)
-        x = nn.utils.rnn.pack_padded_sequence(x,
+        o = x
+        for layer in self.convolutions:
+            o = layer(o)
+        o = o.transpose(1, 2)
+        o = nn.utils.rnn.pack_padded_sequence(o,
                                               input_lengths,
                                               batch_first=True)
         self.lstm.flatten_parameters()
-        outputs, _ = self.lstm(x)
-        outputs, _ = nn.utils.rnn.pad_packed_sequence(
-            outputs,
-            batch_first=True,
-        )
-        return outputs
+        o, _ = self.lstm(o)
+        o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)
+        return o

     def inference(self, x):
-        x = self.convolutions(x)
-        x = x.transpose(1, 2)
+        o = x
+        for layer in self.convolutions:
+            o = layer(o)
+        o = o.transpose(1, 2)
         self.lstm.flatten_parameters()
-        outputs, _ = self.lstm(x)
-        return outputs
-
-    def inference_truncated(self, x):
-        """
-        Preserve encoder state for continuous inference
-        """
-        x = self.convolutions(x)
-        x = x.transpose(1, 2)
-        self.lstm.flatten_parameters()
-        outputs, self.rnn_state = self.lstm(x, self.rnn_state)
-        return outputs
+        o, _ = self.lstm(o)
+        return o


 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm,
+    def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
                  forward_attn_mask, location_attn, attn_K, separate_stopnet,
                  speaker_embedding_dim):
         super(Decoder, self).__init__()
-        self.memory_dim = memory_dim
+        self.frame_dim = frame_dim
         self.r_init = r
         self.r = r
-        self.encoder_embedding_dim = in_features
+        self.encoder_embedding_dim = input_dim
         self.separate_stopnet = separate_stopnet
+        self.max_decoder_steps = 1000
+        self.gate_threshold = 0.5
+
+        # model dimensions
         self.query_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
-        self.max_decoder_steps = 1000
-        self.gate_threshold = 0.5
+        self.attn_dim = 128
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1

         # memory -> |Prenet| -> processed_memory
-        prenet_dim = self.memory_dim
-        self.prenet = Prenet(
-            prenet_dim,
-            prenet_type,
-            prenet_dropout,
-            out_features=[self.prenet_dim, self.prenet_dim],
-            bias=False)
+        prenet_dim = self.frame_dim
+        self.prenet = Prenet(prenet_dim,
+                             prenet_type,
+                             prenet_dropout,
+                             out_features=[self.prenet_dim, self.prenet_dim],
+                             bias=False)

-        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
                                          self.query_dim)

         self.attention = init_attn(attn_type=attn_type,
                                    query_dim=self.query_dim,
-                                   embedding_dim=in_features,
+                                   embedding_dim=input_dim,
                                    attention_dim=128,
                                    location_attention=location_attn,
                                    attention_location_n_filters=32,
@@ -141,15 +137,15 @@ class Decoder(nn.Module):
                                    forward_attn_mask=forward_attn_mask,
                                    attn_K=attn_K)

-        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
+        self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
                                        self.decoder_rnn_dim, 1)

-        self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
-                                        self.memory_dim * self.r_init)
+        self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
+                                        self.frame_dim * self.r_init)

         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
+            Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
                    1,
                    bias=True,
                    init_gain='sigmoid'))
@@ -161,7 +157,7 @@ class Decoder(nn.Module):
     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = torch.zeros(1, device=inputs.device).repeat(B,
-                                self.memory_dim * self.r)
+                                self.frame_dim * self.r)
         return memory

     def _init_states(self, inputs, mask, keep_states=False):
@@ -187,9 +183,9 @@ class Decoder(nn.Module):
         Reshape the spectrograms for given 'r'
         """
         # Grouping multiple frames if necessary
-        if memory.size(-1) == self.memory_dim:
+        if memory.size(-1) == self.frame_dim:
             memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
-        # Time first (T_decoder, B, memory_dim)
+        # Time first (T_decoder, B, frame_dim)
         memory = memory.transpose(0, 1)
         return memory

@@ -197,22 +193,22 @@ class Decoder(nn.Module):
         alignments = torch.stack(alignments).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
+        outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
         outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments

     def _update_memory(self, memory):
         if len(memory.shape) == 2:
-            return memory[:, self.memory_dim * (self.r - 1):]
-        return memory[:, :, self.memory_dim * (self.r - 1):]
+            return memory[:, self.frame_dim * (self.r - 1):]
+        return memory[:, :, self.frame_dim * (self.r - 1):]

     def decode(self, memory):
         '''
          shapes:
-           - memory: B x r * self.memory_dim
+           - memory: B x r * self.frame_dim
         '''
         # self.context: B x D_en
-        # query_input: B x D_en + (r * self.memory_dim)
+        # query_input: B x D_en + (r * self.frame_dim)
         query_input = torch.cat((memory, self.context), -1)
         # self.query and self.attention_rnn_cell_state : B x D_attn_rnn
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
@@ -235,16 +231,16 @@ class Decoder(nn.Module):
         # B x (D_decoder_rnn + D_en)
         decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
                                            dim=1)
-        # B x (self.r * self.memory_dim)
+        # B x (self.r * self.frame_dim)
         decoder_output = self.linear_projection(decoder_hidden_context)
-        # B x (D_decoder_rnn + (self.r * self.memory_dim))
+        # B x (D_decoder_rnn + (self.r * self.frame_dim))
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
         # select outputs for the reduction rate self.r
-        decoder_output = decoder_output[:, :self.r * self.memory_dim]
+        decoder_output = decoder_output[:, :self.r * self.frame_dim]
         return decoder_output, self.attention.attention_weights, stop_token

     def forward(self, inputs, memories, mask, speaker_embeddings=None):
diff --git a/models/tacotron2.py b/models/tacotron2.py
index d530774a..3e7adfca 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
         super(Tacotron2, self).__init__()
         self.postnet_output_dim = postnet_output_dim
         self.decoder_output_dim = decoder_output_dim
-        self.n_frames_per_step = r
+        self.r = r
         self.bidirectional_decoder = bidirectional_decoder
         decoder_dim = 512 if num_speakers > 1 else 512
         encoder_dim = 512 if num_speakers > 1 else 512
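
Note: renaming nn.Module attributes also renames the corresponding state_dict
keys, so checkpoints saved before this patch will no longer load with strict
key matching. The helper below is a minimal, hypothetical sketch (it is not
part of the patch): the old/new key fragments are taken from the renames above
(Prenet "layers" -> "linear_layers", LinearBN "bn" -> "batch_normalization",
LocationLayer "location_conv" -> "location_conv1d", and ConvBNBlock's former
nn.Sequential slots "net.0"/"net.1" -> "convolution1d"/"batch_normalization"),
but the exact key paths depend on how these modules are nested in the full
model, so verify them against your own checkpoint before relying on it.

    from collections import OrderedDict

    # Old -> new key fragments implied by the attribute renames in this patch.
    # Patterns keep surrounding dots where needed so they cannot match inside
    # an unrelated key such as "stopnet.1.".
    KEY_RENAMES = [
        ("location_conv.", "location_conv1d."),  # LocationLayer
        (".layers.", ".linear_layers."),         # Prenet
        (".bn.", ".batch_normalization."),       # LinearBN
        (".net.0.", ".convolution1d."),          # ConvBNBlock: slot 0 was the Conv1d
        (".net.1.", ".batch_normalization."),    # ConvBNBlock: slot 1 was the BatchNorm1d
    ]

    def migrate_state_dict(old_state):
        """Rewrite pre-rename checkpoint keys to the new attribute names."""
        new_state = OrderedDict()
        for key, value in old_state.items():
            new_key = key
            for old, new in KEY_RENAMES:
                new_key = new_key.replace(old, new)
            new_state[new_key] = value
        return new_state

    # Usage (hypothetical file name and checkpoint layout):
    #   checkpoint = torch.load("tacotron2_before_rename.pth", map_location="cpu")
    #   model.load_state_dict(migrate_state_dict(checkpoint["model"]))

The forward passes are only reorganized, not changed numerically, so a
checkpoint migrated this way should produce the same outputs; the renames just
give every layer a stable, descriptive attribute name that the TF counterpart
can mirror.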