mirror of https://github.com/coqui-ai/TTS.git
Add `AlignerNetwork`
This commit is contained in:
parent
648655fa03
commit
59b24e66cf
|
@ -0,0 +1,81 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class AlignmentNetwork(torch.nn.Module):
|
||||
"""Aligner Network for learning alignment between the input text and the model output with Gaussian Attention.
|
||||
|
||||
::
|
||||
|
||||
query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment
|
||||
key -> conv1d -> relu -> conv1d -----------------------^
|
||||
|
||||
Args:
|
||||
in_query_channels (int): Number of channels in the query network. Defaults to 80.
|
||||
in_key_channels (int): Number of channels in the key network. Defaults to 512.
|
||||
attn_channels (int): Number of inner channels in the attention layers. Defaults to 80.
|
||||
temperature (float): Temperature for the softmax. Defaults to 0.0005.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_query_channels=80,
|
||||
in_key_channels=512,
|
||||
attn_channels=80,
|
||||
temperature=0.0005,
|
||||
):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.softmax = torch.nn.Softmax(dim=3)
|
||||
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||
|
||||
self.key_layer = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
in_key_channels,
|
||||
in_key_channels * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=True,
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
|
||||
)
|
||||
|
||||
self.query_layer = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
in_query_channels,
|
||||
in_query_channels * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=True,
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
|
||||
torch.nn.ReLU(),
|
||||
nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
|
||||
) -> Tuple[torch.tensor, torch.tensor]:
|
||||
"""Forward pass of the aligner encoder.
|
||||
Shapes:
|
||||
- queries: :math:`[B, C, T_de]`
|
||||
- keys: :math:`[B, C_emb, T_en]`
|
||||
- mask: :math:`[B, T_de]`
|
||||
Output:
|
||||
attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
|
||||
attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities.
|
||||
"""
|
||||
key_out = self.key_layer(keys)
|
||||
query_out = self.query_layer(queries)
|
||||
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
|
||||
attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True)
|
||||
if attn_prior is not None:
|
||||
attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8)
|
||||
if mask is not None:
|
||||
attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
|
||||
attn = self.softmax(attn_logp)
|
||||
return attn, attn_logp
|
Loading…
Reference in New Issue