# coqui-tts/TTS/tts/layers/glow_tts/encoder.py

import math
import torch
from torch import nn
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.gated_conv import GatedConvBlock
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock


class Encoder(nn.Module):
    """Glow-TTS encoder module.

        embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
                                                              |
                                                              |-> proj_var
                                                              |
                                                              |-> concat -> duration_predictor
                                                                     ^
                                                              speaker_embed

    Args:
        num_chars (int): number of characters.
        out_channels (int): number of output channels.
        hidden_channels (int): encoder's embedding size.
        hidden_channels_dp (int): duration predictor's hidden channels.
        encoder_type (str): encoder module type: 'rel_pos_transformer',
            'gated_conv', 'residual_conv_bn' or 'time_depth_separable'.
        encoder_params (dict): keyword arguments forwarded to the encoder
            module (see the suggested values in the Notes below).
        dropout_p_dp (float): dropout rate of the duration predictor.
        mean_only (bool): if True, output only mean values and use constant std.
        use_prenet (bool): if True, use pre-convolutional layers before the encoder module.
        c_in_channels (int): number of channels of the conditional (e.g. speaker embedding) input.

    Shapes:
        - input: (B, T)

    Notes:
        Suggested encoder params...

        for encoder_type == 'rel_pos_transformer'

            encoder_params = {
                'kernel_size': 3,
                'dropout_p': 0.1,
                'num_layers': 6,
                'num_heads': 2,
                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
                'input_length': None
            }

        for encoder_type == 'gated_conv'

            encoder_params = {
                'kernel_size': 5,
                'dropout_p': 0.1,
                'num_layers': 9,
            }

        for encoder_type == 'residual_conv_bn'

            encoder_params = {
                'kernel_size': 4,
                'dilations': [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
                'num_conv_blocks': 2,
                'num_res_blocks': 13
            }

        for encoder_type == 'time_depth_separable'

            encoder_params = {
                'kernel_size': 5,
                'num_layers': 9,
            }
    """
def __init__(self,
num_chars,
out_channels,
hidden_channels,
hidden_channels_dp,
encoder_type,
encoder_params,
dropout_p_dp=0.1,
mean_only=False,
use_prenet=True,
c_in_channels=0):
super().__init__()
# class arguments
self.num_chars = num_chars
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.hidden_channels_dp = hidden_channels_dp
self.dropout_p_dp = dropout_p_dp
self.mean_only = mean_only
self.use_prenet = use_prenet
self.c_in_channels = c_in_channels
self.encoder_type = encoder_type
# embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
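        # the N(0, d^-0.5) init above pairs with the sqrt(d) scaling applied
        # in forward(), keeping embedded outputs at roughly unit variance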
# init encoder
        if encoder_type.lower() == "rel_pos_transformer":
# optional convolutional prenet
if use_prenet:
self.pre = ConvLayerNorm(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
            # text encoder
            self.encoder = RelativePositionTransformer(hidden_channels,
                                                       **encoder_params)
elif encoder_type.lower() == 'gated_conv':
            self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == 'residual_conv_bn':
if use_prenet:
self.pre = nn.Sequential(
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU()
)
            self.encoder = ResidualConvBNBlock(hidden_channels,
                                               **encoder_params)
elif encoder_type.lower() == 'time_depth_separable':
# optional convolutional prenet
if use_prenet:
self.pre = ConvLayerNorm(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
            self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
                                                       hidden_channels,
                                                       hidden_channels,
                                                       **encoder_params)
else:
            raise ValueError(" [!] Unknown encoder type.")
# final projection layers
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
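        # proj_s predicts per-frame log-scale (log std); when mean_only is
        # set, forward() uses a constant std of 1 (x_logs = 0) instead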
if not mean_only:
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
# duration predictor
        self.duration_predictor = DurationPredictor(
            hidden_channels + c_in_channels, hidden_channels_dp, 3,
            dropout_p_dp)

    def forward(self, x, x_lengths, g=None):
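        """
        Shapes:
            - x: [B, T] integer tensor of character ids.
            - x_lengths: [B] tensor of sequence lengths.
            - g: [B, c_in_channels, 1], optional conditioning (e.g. speaker) embedding.
        """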
# embedding layer
        # [B, T, D]
x = self.emb(x) * math.sqrt(self.hidden_channels)
# [B, D, T]
x = torch.transpose(x, 1, -1)
# compute input sequence mask
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
1).to(x.dtype)
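        # e.g. x_lengths = [5, 3] with T = 5 gives
        # x_mask = [[[1, 1, 1, 1, 1]], [[1, 1, 1, 0, 0]]]  # [B, 1, T]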
# pre-conv layers
if hasattr(self, 'pre') and self.use_prenet:
x = self.pre(x, x_mask)
# encoder
x = self.encoder(x, x_mask)
        # set duration predictor input; detach stops gradients from the
        # duration loss flowing back into the encoder states
        if g is not None:
            # expand conditioning (e.g. speaker) embedding over time
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([x.detach(), g_exp], 1)
        else:
            x_dp = x.detach()
# final projection layer
x_m = self.proj_m(x) * x_mask
if not self.mean_only:
x_logs = self.proj_s(x) * x_mask
else:
x_logs = torch.zeros_like(x_m)
# duration predictor
logw = self.duration_predictor(x_dp, x_mask)
return x_m, x_logs, logw, x_mask
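

if __name__ == "__main__":
    # Minimal smoke test; a sketch assuming the suggested
    # 'rel_pos_transformer' params from the docstring above. Batch size,
    # sequence lengths and hyper-parameter values here are illustrative only.
    encoder = Encoder(num_chars=100,
                      out_channels=80,
                      hidden_channels=192,
                      hidden_channels_dp=256,
                      encoder_type='rel_pos_transformer',
                      encoder_params={
                          'kernel_size': 3,
                          'dropout_p': 0.1,
                          'num_layers': 6,
                          'num_heads': 2,
                          'hidden_channels_ffn': 768,
                          'input_length': None
                      })
    x = torch.randint(0, 100, (2, 37))  # [B, T] character ids
    x_lengths = torch.tensor([37, 21])  # [B] sequence lengths
    x_m, x_logs, logw, x_mask = encoder(x, x_lengths)
    # expected (per the shapes above): [2, 80, 37], [2, 80, 37],
    # [2, 1, 37], [2, 1, 37]
    print(x_m.shape, x_logs.shape, logw.shape, x_mask.shape)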