From 88575edd5a5ea4e3ea8743e42b3ae6899e007a5e Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 5 Jun 2019 18:34:48 +0200
Subject: [PATCH] gst_layers

---
 layers/gst_layers.py | 168 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 168 insertions(+)
 create mode 100644 layers/gst_layers.py

diff --git a/layers/gst_layers.py b/layers/gst_layers.py
new file mode 100644
index 00000000..647712d7
--- /dev/null
+++ b/layers/gst_layers.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GST(nn.Module):
+    """Global Style Token module for factorizing prosody in speech.
+
+    See https://arxiv.org/pdf/1803.09017"""
+
+    def __init__(self, num_mel, num_heads, num_style_tokens, embedding_dim):
+        super().__init__()
+        self.encoder = ReferenceEncoder(num_mel, embedding_dim)
+        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens,
+                                                 embedding_dim)
+
+    def forward(self, inputs):
+        enc_out = self.encoder(inputs)
+        style_embed = self.style_token_layer(enc_out)
+
+        return style_embed
+
+
+class ReferenceEncoder(nn.Module):
+    """NN module creating a fixed size prosody embedding from a spectrogram.
+
+    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
+    outputs: [batch_size, embedding_dim // 2]
+    """
+
+    def __init__(self, num_mel, embedding_dim):
+        super().__init__()
+        self.num_mel = num_mel
+        filters = [1] + [32, 32, 64, 64, 128, 128]
+        num_layers = len(filters) - 1
+        convs = [
+            nn.Conv2d(
+                in_channels=filters[i],
+                out_channels=filters[i + 1],
+                kernel_size=(3, 3),
+                stride=(2, 2),
+                padding=(1, 1)) for i in range(num_layers)
+        ]
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList([
+            nn.BatchNorm2d(num_features=filter_size)
+            for filter_size in filters[1:]
+        ])
+
+        post_conv_height = self.calculate_post_conv_height(
+            num_mel, 3, 2, 1, num_layers)
+        self.recurrence = nn.GRU(
+            input_size=filters[-1] * post_conv_height,
+            hidden_size=embedding_dim // 2,
+            batch_first=True)
+
+    def forward(self, inputs):
+        batch_size = inputs.size(0)
+        x = inputs.view(batch_size, 1, -1, self.num_mel)
+        # x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel]
+        for conv, bn in zip(self.convs, self.bns):
+            x = conv(x)
+            x = bn(x)
+            x = F.relu(x)
+
+        x = x.transpose(1, 2)
+        # x: 4D tensor [batch_size, post_conv_width,
+        #     num_channels==128, post_conv_height]
+        post_conv_width = x.size(1)
+        x = x.contiguous().view(batch_size, post_conv_width, -1)
+        # x: 3D tensor [batch_size, post_conv_width,
+        #     num_channels*post_conv_height]
+        self.recurrence.flatten_parameters()
+        _memory, out = self.recurrence(x)
+        # out: final GRU hidden state,
+        # 3D tensor [num_layers==1, batch_size, embedding_dim//2]
+
+        return out.squeeze(0)
+
+    @staticmethod
+    def calculate_post_conv_height(height, kernel_size, stride, pad,
+                                   n_convs):
+        """Height of spec after n convolutions with fixed kernel/stride/pad."""
+        for _ in range(n_convs):
+            height = (height - kernel_size + 2 * pad) // stride + 1
+        return height
+
+
+class StyleTokenLayer(nn.Module):
+    """NN module attending to style tokens based on prosody encodings."""
+
+    def __init__(self, num_heads, num_style_tokens,
+                 embedding_dim):
+        super().__init__()
+        self.query_dim = embedding_dim // 2
+        self.key_dim = embedding_dim // num_heads
+        self.style_tokens = nn.Parameter(
+            torch.FloatTensor(num_style_tokens, self.key_dim))
+        nn.init.orthogonal_(self.style_tokens)
+        self.attention = MultiHeadAttention(
+            query_dim=self.query_dim,
+            key_dim=self.key_dim,
+            num_units=embedding_dim,
+            num_heads=num_heads)
+
+    def forward(self, inputs):
+        batch_size = inputs.size(0)
+        prosody_encoding = inputs.unsqueeze(1)
+        # prosody_encoding: 3D tensor [batch_size, 1, embedding_dim//2]
+        tokens = torch.tanh(self.style_tokens) \
+            .unsqueeze(0) \
+            .expand(batch_size, -1, -1)
+        # tokens: 3D tensor [batch_size, num_style_tokens, key_dim]
+        style_embed = self.attention(prosody_encoding, tokens)
+
+        return style_embed
+
+
+class MultiHeadAttention(nn.Module):
+    """Multi-head dot-product attention.
+
+    input:
+        query --- [N, T_q, query_dim]
+        key --- [N, T_k, key_dim]
+    output:
+        out --- [N, T_q, num_units]
+    """
+
+    def __init__(self, query_dim, key_dim, num_units, num_heads):
+        super().__init__()
+        self.num_units = num_units
+        self.num_heads = num_heads
+        self.key_dim = key_dim
+
+        self.W_query = nn.Linear(
+            in_features=query_dim, out_features=num_units, bias=False)
+        self.W_key = nn.Linear(
+            in_features=key_dim, out_features=num_units, bias=False)
+        self.W_value = nn.Linear(
+            in_features=key_dim, out_features=num_units, bias=False)
+
+    def forward(self, query, key):
+        queries = self.W_query(query)  # [N, T_q, num_units]
+        keys = self.W_key(key)  # [N, T_k, num_units]
+        values = self.W_value(key)  # [N, T_k, num_units]
+
+        split_size = self.num_units // self.num_heads
+        queries = torch.stack(
+            torch.split(queries, split_size, dim=2),
+            dim=0)  # [h, N, T_q, num_units/h]
+        keys = torch.stack(
+            torch.split(keys, split_size, dim=2),
+            dim=0)  # [h, N, T_k, num_units/h]
+        values = torch.stack(
+            torch.split(values, split_size, dim=2),
+            dim=0)  # [h, N, T_k, num_units/h]
+
+        # scores = softmax(QK^T / (key_dim ** 0.5))
+        scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
+        scores = scores / (self.key_dim**0.5)
+        scores = F.softmax(scores, dim=3)
+
+        # out: attention-weighted sum of the values
+        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
+        out = torch.cat(
+            torch.split(out, 1, dim=0),
+            dim=3).squeeze(0)  # [N, T_q, num_units]
+
+        return out
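
Usage sketch (reviewer note, not part of the patch): a minimal smoke test for
the module added above. The hyperparameter values here are hypothetical and
chosen only for illustration; embedding_dim must be even (the GRU hidden size
is embedding_dim // 2) and divisible by num_heads. The expected output shape
follows from the code in this diff: the style token layer returns
[batch_size, 1, embedding_dim].

    import torch
    from layers.gst_layers import GST

    # Hypothetical sizes for illustration only.
    num_mel, num_heads, num_style_tokens, embedding_dim = 80, 4, 10, 256

    gst = GST(num_mel, num_heads, num_style_tokens, embedding_dim)

    # Dummy batch of mel spectrograms: [batch_size, num_spec_frames, num_mel].
    mels = torch.rand(8, 128, num_mel)
    style_embed = gst(mels)
    print(style_embed.shape)  # torch.Size([8, 1, 256]) == [batch, 1, embedding_dim]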