mirror of https://github.com/coqui-ai/TTS.git
gst_layers
parent 7410daceb2
commit 88575edd5a
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class GST(nn.Module):
    """Global Style Token Module for factorizing prosody in speech.

    See https://arxiv.org/pdf/1803.09017"""

    def __init__(self, num_mel, num_heads, num_style_tokens, embedding_dim):
        super().__init__()
        self.encoder = ReferenceEncoder(num_mel, embedding_dim)
        self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens,
                                                 embedding_dim)

    def forward(self, inputs):
        enc_out = self.encoder(inputs)
        style_embed = self.style_token_layer(enc_out)

        return style_embed
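
# In use, GST maps a batch of mel spectrograms [batch_size, num_spec_frames,
# num_mel] to a style embedding [batch_size, 1, embedding_dim]; see the
# smoke-test sketch at the bottom of this file.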


class ReferenceEncoder(nn.Module):
    """NN module creating a fixed size prosody embedding from a spectrogram.

    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
    outputs: [batch_size, embedding_dim]
    """

    def __init__(self, num_mel, embedding_dim):
        super().__init__()
        self.num_mel = num_mel
        filters = [1] + [32, 32, 64, 64, 128, 128]
        num_layers = len(filters) - 1
        convs = [
            nn.Conv2d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1)) for i in range(num_layers)
        ]
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(num_features=filter_size)
            for filter_size in filters[1:]
        ])

        post_conv_height = self.calculate_post_conv_height(
            num_mel, 3, 2, 1, num_layers)
        self.recurrence = nn.GRU(
            input_size=filters[-1] * post_conv_height,
            hidden_size=embedding_dim // 2,
            batch_first=True)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        x = inputs.view(batch_size, 1, -1, self.num_mel)
        # x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel]
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x)
            x = bn(x)
            x = F.relu(x)

        x = x.transpose(1, 2)
        # x: 4D tensor [batch_size, post_conv_width,
        #               num_channels==128, post_conv_height]
        post_conv_width = x.size(1)
        x = x.contiguous().view(batch_size, post_conv_width, -1)
        # x: 3D tensor [batch_size, post_conv_width,
        #               num_channels*post_conv_height]
        self.recurrence.flatten_parameters()
        memory, out = self.recurrence(x)
        # out: final GRU hidden state,
        # [num_layers==1, batch_size, embedding_dim // 2]

        return out.squeeze(0)

    @staticmethod
    def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
        """Height of spec after n convolutions with fixed kernel/stride/pad."""
        for _ in range(n_convs):
            height = (height - kernel_size + 2 * pad) // stride + 1
        return height
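
    # Worked example: with num_mel == 80 and the six conv layers above
    # (kernel 3, stride 2, pad 1), the height shrinks
    # 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2, so the GRU input size is
    # 128 * 2 == 256.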


class StyleTokenLayer(nn.Module):
    """NN Module attending to style tokens based on prosody encodings."""

    def __init__(self, num_heads, num_style_tokens, embedding_dim):
        super().__init__()
        self.query_dim = embedding_dim // 2
        self.key_dim = embedding_dim // num_heads
        self.style_tokens = nn.Parameter(
            torch.FloatTensor(num_style_tokens, self.key_dim))
        nn.init.orthogonal_(self.style_tokens)
        self.attention = MultiHeadAttention(
            query_dim=self.query_dim,
            key_dim=self.key_dim,
            num_units=embedding_dim,
            num_heads=num_heads)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        prosody_encoding = inputs.unsqueeze(1)
        # prosody_encoding: 3D tensor [batch_size, 1, embedding_dim // 2]
        tokens = torch.tanh(self.style_tokens) \
            .unsqueeze(0) \
            .expand(batch_size, -1, -1)
        # tokens: 3D tensor [batch_size, num_style_tokens, key_dim]
        style_embed = self.attention(prosody_encoding, tokens)

        return style_embed
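
# Shape sketch for StyleTokenLayer.forward, assuming embedding_dim == 256,
# num_heads == 4 and num_style_tokens == 10: inputs [B, 128] becomes the
# query [B, 1, 128], the tokens expand to [B, 10, 64], and the attention
# returns style_embed [B, 1, 256].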


class MultiHeadAttention(nn.Module):
    """Multi-head attention over the style tokens.

    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    """

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        queries = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        # Split the projections into num_heads chunks and stack them on a
        # new leading head axis.
        split_size = self.num_units // self.num_heads
        queries = torch.stack(
            torch.split(queries, split_size, dim=2),
            dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2),
            dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(
            torch.split(values, split_size, dim=2),
            dim=0)  # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)

        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(
            torch.split(out, 1, dim=0),
            dim=3).squeeze(0)  # [N, T_q, num_units]

        return out
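

# A minimal smoke-test sketch. The hyperparameters below (80 mel bands,
# 4 attention heads, 10 style tokens, a 256-dim style embedding and
# 50-frame inputs) are illustrative assumptions, not values fixed by this
# module.
if __name__ == "__main__":
    torch.manual_seed(0)
    gst = GST(num_mel=80, num_heads=4, num_style_tokens=10,
              embedding_dim=256)
    mels = torch.randn(2, 50, 80)  # [batch_size, num_spec_frames, num_mel]
    style = gst(mels)
    print(style.shape)  # expected: torch.Size([2, 1, 256])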