mirror of https://github.com/coqui-ai/TTS.git
add gated conv encoder to glow-tts
This commit is contained in:
parent
14356d3250
commit
1b238f04b2
|
@ -110,6 +110,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
kernel_size=3,
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
encoder_type=c.encoder_type,
|
||||
dropout_p=0.1,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
|
|
|
@ -4,14 +4,53 @@ from torch import nn
|
|||
|
||||
from mozilla_voice_tts.tts.layers.glow_tts.transformer import Transformer
|
||||
from mozilla_voice_tts.tts.utils.generic_utils import sequence_mask
|
||||
from mozilla_voice_tts.tts.layers.glow_tts.glow import ConvLayerNorm
|
||||
from mozilla_voice_tts.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
|
||||
from mozilla_voice_tts.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
|
||||
|
||||
class GatedConvBlock(nn.Module):
|
||||
"""Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf
|
||||
Args:
|
||||
in_out_channels (int): number of input/output channels.
|
||||
kernel_size (int): convolution kernel size.
|
||||
dropout_p (float): dropout rate.
|
||||
"""
|
||||
def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
|
||||
super().__init__()
|
||||
# class arguments
|
||||
self.dropout_p = dropout_p
|
||||
self.num_layers = num_layers
|
||||
# define layers
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
self.conv_layers += [
|
||||
nn.Conv1d(in_out_channels,
|
||||
2 * in_out_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
]
|
||||
self.norm_layers += [LayerNorm(2 * in_out_channels)]
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
o = x
|
||||
res = x
|
||||
for idx in range(self.num_layers):
|
||||
o = nn.functional.dropout(o,
|
||||
p=self.dropout_p,
|
||||
training=self.training)
|
||||
o = self.conv_layers[idx](o * x_mask)
|
||||
o = self.norm_layers[idx](o)
|
||||
o = nn.functional.glu(o, dim=1)
|
||||
o = res + o
|
||||
res = o
|
||||
return o
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Glow-TTS encoder module. We use Pytorch TransformerEncoder instead
|
||||
of the one with relative position embedding. We use positional encoding
|
||||
for capturing positiong information.
|
||||
"""Glow-TTS encoder module. It uses Transformer with Relative Pos.Encoding
|
||||
as in the original paper or GatedConvBlock as a faster alternative.
|
||||
|
||||
Args:
|
||||
num_chars (int): number of characters.
|
||||
|
@ -29,13 +68,13 @@ class Encoder(nn.Module):
|
|||
Shapes:
|
||||
- input: (B, T, C)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
filter_channels_dp,
|
||||
encoder_type,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size,
|
||||
|
@ -59,26 +98,36 @@ class Encoder(nn.Module):
|
|||
self.mean_only = mean_only
|
||||
self.use_prenet = use_prenet
|
||||
self.c_in_channels = c_in_channels
|
||||
self.encoder_type = encoder_type
|
||||
# embedding layer
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
||||
# optional convolutional prenet
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3,
|
||||
dropout_p=0.5)
|
||||
# text encoder
|
||||
self.encoder = Transformer(hidden_channels,
|
||||
filter_channels,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
rel_attn_window_size=rel_attn_window_size,
|
||||
input_length=input_length)
|
||||
# init encoder
|
||||
if encoder_type.lower() == "transformer":
|
||||
# optional convolutional prenet
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3,
|
||||
dropout_p=0.5)
|
||||
# text encoder
|
||||
self.encoder = Transformer(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
rel_attn_window_size=rel_attn_window_size,
|
||||
input_length=input_length)
|
||||
elif encoder_type.lower() == 'gatedconv':
|
||||
breakpoint()
|
||||
self.encoder = GatedConvBlock(hidden_channels,
|
||||
kernel_size=5,
|
||||
dropout_p=dropout_p,
|
||||
num_layers=3 + num_layers)
|
||||
# final projection layers
|
||||
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
if not mean_only:
|
||||
|
@ -98,8 +147,9 @@ class Encoder(nn.Module):
|
|||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
|
||||
1).to(x.dtype)
|
||||
# pre-conv layers
|
||||
if self.use_prenet:
|
||||
x = self.pre(x, x_mask)
|
||||
if self.encoder_type == 'transformer':
|
||||
if self.use_prenet:
|
||||
x = self.pre(x, x_mask)
|
||||
# encoder
|
||||
x = self.encoder(x, x_mask)
|
||||
# set duration predictor input
|
||||
|
|
|
@ -36,7 +36,8 @@ class GlowTts(nn.Module):
|
|||
mean_only=False,
|
||||
hidden_channels_enc=None,
|
||||
hidden_channels_dec=None,
|
||||
use_encoder_prenet=False):
|
||||
use_encoder_prenet=False,
|
||||
encoder_type="transformer"):
|
||||
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
|
@ -72,6 +73,7 @@ class GlowTts(nn.Module):
|
|||
hidden_channels_enc or hidden_channels,
|
||||
filter_channels,
|
||||
filter_channels_dp,
|
||||
encoder_type,
|
||||
num_heads,
|
||||
num_layers_enc,
|
||||
kernel_size,
|
||||
|
|
Loading…
Reference in New Issue