mirror of https://github.com/coqui-ai/TTS.git
graves attention [WIP]
This commit is contained in:
parent
537cd66f27
commit
84d81b6579
|
@ -105,6 +105,70 @@ class LocationLayer(nn.Module):
|
||||||
return processed_attention
|
return processed_attention
|
||||||
|
|
||||||
|
|
||||||
|
class GravesAttention(nn.Module):
|
||||||
|
COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi))
|
||||||
|
|
||||||
|
def __init__(self, query_dim, K, attention_alignment=0.05):
|
||||||
|
super(GravesAttention, self).__init__()
|
||||||
|
self._mask_value = -float("inf")
|
||||||
|
self.K = K
|
||||||
|
self.attention_alignment = attention_alignment
|
||||||
|
self.epsilon = 1e-5
|
||||||
|
self.J = None
|
||||||
|
self.N_a = nn.Sequential(
|
||||||
|
nn.Linear(query_dim, query_dim//2),
|
||||||
|
nn.Tanh(),
|
||||||
|
nn.Linear(query_dim//2, 3*K))
|
||||||
|
self.mu_tm1 = None
|
||||||
|
|
||||||
|
def init_states(self, inputs):
|
||||||
|
if self.J is None or inputs.shape[1] > self.J.shape[-1]:
|
||||||
|
self.J = torch.arange(0, inputs.shape[1]).expand_as(torch.Tensor(inputs.shape[0], self.K, inputs.shape[1])).to(inputs.device)
|
||||||
|
self.mu_tm1 = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
||||||
|
|
||||||
|
def forward(self, query, inputs, mask):
|
||||||
|
"""
|
||||||
|
shapes:
|
||||||
|
query: B x D_attention_rnn
|
||||||
|
inputs: B x T_in x D_encoder
|
||||||
|
mask: B x T_in
|
||||||
|
"""
|
||||||
|
gbk_t = self.N_a(query)
|
||||||
|
gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K)
|
||||||
|
|
||||||
|
# attention model parameters
|
||||||
|
# each B x K
|
||||||
|
g_t = gbk_t[:, 0, :]
|
||||||
|
b_t = gbk_t[:, 1, :]
|
||||||
|
k_t = gbk_t[:, 2, :]
|
||||||
|
|
||||||
|
# attention GMM parameters
|
||||||
|
g_t = torch.softmax(g_t, dim=-1) + self.epsilon # distribution weight
|
||||||
|
sig_t = torch.exp(b_t) + self.epsilon # variance
|
||||||
|
mu_t = self.mu_tm1 + self.attention_alignment * torch.exp(k_t) # mean
|
||||||
|
|
||||||
|
g_t = g_t.unsqueeze(2).expand(g_t.size(0),
|
||||||
|
g_t.size(1),
|
||||||
|
inputs.size(1))
|
||||||
|
sig_t = sig_t.unsqueeze(2).expand_as(g_t)
|
||||||
|
mu_t_ = mu_t.unsqueeze(2).expand_as(g_t)
|
||||||
|
j = self.J[:g_t.size(0), :, :inputs.size(1)]
|
||||||
|
|
||||||
|
# attention weights
|
||||||
|
phi_t = g_t * torch.exp(-0.5 * sig_t * (mu_t_ - j)**2)
|
||||||
|
alpha_t = self.COEF * torch.sum(phi_t, 1)
|
||||||
|
|
||||||
|
# apply masking
|
||||||
|
# if mask is not None:
|
||||||
|
# alpha_t.data.masked_fill_(~mask, self._mask_value)
|
||||||
|
|
||||||
|
breakpoint()
|
||||||
|
|
||||||
|
c_t = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
|
||||||
|
self.mu_tm1 = mu_t
|
||||||
|
return c_t, mu_t, alpha_t
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
# Pylint gets confused by PyTorch conventions here
|
# Pylint gets confused by PyTorch conventions here
|
||||||
#pylint: disable=attribute-defined-outside-init
|
#pylint: disable=attribute-defined-outside-init
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .common_layers import Prenet, Attention, Linear
|
from .common_layers import Prenet, Attention, Linear, GravesAttention
|
||||||
|
|
||||||
|
|
||||||
class BatchNormConv1d(nn.Module):
|
class BatchNormConv1d(nn.Module):
|
||||||
|
@ -288,17 +288,18 @@ class Decoder(nn.Module):
|
||||||
# attention_rnn generates queries for the attention mechanism
|
# attention_rnn generates queries for the attention mechanism
|
||||||
self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
|
self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
|
||||||
|
|
||||||
self.attention = Attention(query_dim=self.query_dim,
|
# self.attention = Attention(query_dim=self.query_dim,
|
||||||
embedding_dim=in_features,
|
# embedding_dim=in_features,
|
||||||
attention_dim=128,
|
# attention_dim=128,
|
||||||
location_attention=location_attn,
|
# location_attention=location_attn,
|
||||||
attention_location_n_filters=32,
|
# attention_location_n_filters=32,
|
||||||
attention_location_kernel_size=31,
|
# attention_location_kernel_size=31,
|
||||||
windowing=attn_windowing,
|
# windowing=attn_windowing,
|
||||||
norm=attn_norm,
|
# norm=attn_norm,
|
||||||
forward_attn=forward_attn,
|
# forward_attn=forward_attn,
|
||||||
trans_agent=trans_agent,
|
# trans_agent=trans_agent,
|
||||||
forward_attn_mask=forward_attn_mask)
|
# forward_attn_mask=forward_attn_mask)
|
||||||
|
self.attention = GravesAttention(self.query_dim, 5)
|
||||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||||
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
|
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
|
||||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||||
|
@ -342,7 +343,7 @@ class Decoder(nn.Module):
|
||||||
]
|
]
|
||||||
self.context_vec = inputs.data.new(B, self.in_features).zero_()
|
self.context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||||
# cache attention inputs
|
# cache attention inputs
|
||||||
self.processed_inputs = self.attention.inputs_layer(inputs)
|
# self.processed_inputs = self.attention.inputs_layer(inputs)
|
||||||
|
|
||||||
def _parse_outputs(self, outputs, attentions, stop_tokens):
|
def _parse_outputs(self, outputs, attentions, stop_tokens):
|
||||||
# Back to batch first
|
# Back to batch first
|
||||||
|
@ -362,7 +363,7 @@ class Decoder(nn.Module):
|
||||||
torch.cat((processed_memory, self.context_vec), -1),
|
torch.cat((processed_memory, self.context_vec), -1),
|
||||||
self.attention_rnn_hidden)
|
self.attention_rnn_hidden)
|
||||||
self.context_vec = self.attention(
|
self.context_vec = self.attention(
|
||||||
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
self.attention_rnn_hidden, inputs, mask)
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
|
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
|
||||||
|
|
Loading…
Reference in New Issue