mirror of https://github.com/coqui-ai/TTS.git
glow ttss refactoring
This commit is contained in:
parent
29f4329d7f
commit
e82d31b6ac
|
@ -94,49 +94,32 @@ class Encoder(nn.Module):
|
|||
# embedding layer
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
||||
# init encoder
|
||||
if encoder_type.lower() == "transformer":
|
||||
# optional convolutional prenet
|
||||
# init encoder module
|
||||
if encoder_type.lower() == "rel_pos_transformer":
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
self.prenet = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3,
|
||||
dropout_p=0.5)
|
||||
# text encoder
|
||||
self.encoder = Transformer(
|
||||
hidden_channels,
|
||||
hidden_channels_ffn,
|
||||
num_heads,
|
||||
num_layers,
|
||||
kernel_size=3,
|
||||
dropout_p=dropout_p,
|
||||
rel_attn_window_size=rel_attn_window_size,
|
||||
input_length=input_length)
|
||||
self.encoder = RelativePositionTransformer(
|
||||
hidden_channels, **encoder_params)
|
||||
elif encoder_type.lower() == 'gated_conv':
|
||||
self.encoder = GatedConvBlock(hidden_channels,
|
||||
kernel_size=5,
|
||||
dropout_p=dropout_p,
|
||||
num_layers=3 + num_layers)
|
||||
self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
|
||||
elif encoder_type.lower() == 'residual_conv_bn':
|
||||
if use_prenet:
|
||||
self.pre = nn.Sequential(
|
||||
self.prenet = nn.Sequential(
|
||||
nn.Conv1d(hidden_channels, hidden_channels, 1),
|
||||
nn.ReLU()
|
||||
)
|
||||
dilations = 4 * [1, 2, 4] + [1]
|
||||
num_conv_blocks = 2
|
||||
num_res_blocks = 13 # total 2 * 13 blocks
|
||||
self.encoder = ResidualConvBNBlock(hidden_channels,
|
||||
kernel_size=4,
|
||||
dilations=dilations,
|
||||
num_res_blocks=num_res_blocks,
|
||||
num_conv_blocks=num_conv_blocks)
|
||||
self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
|
||||
self.postnet = nn.Sequential(
|
||||
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
|
||||
nn.BatchNorm1d(self.hidden_channels))
|
||||
elif encoder_type.lower() == 'time_depth_separable':
|
||||
# optional convolutional prenet
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
self.prenet = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
|
@ -145,8 +128,7 @@ class Encoder(nn.Module):
|
|||
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3 + num_layers)
|
||||
**encoder_params)
|
||||
else:
|
||||
raise ValueError(" [!] Unkown encoder type.")
|
||||
|
||||
|
@ -157,7 +139,7 @@ class Encoder(nn.Module):
|
|||
# duration predictor
|
||||
self.duration_predictor = DurationPredictor(
|
||||
hidden_channels + c_in_channels, hidden_channels_dp, 3,
|
||||
dropout_p)
|
||||
dropout_p_dp)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
# embedding layer
|
||||
|
@ -168,11 +150,14 @@ class Encoder(nn.Module):
|
|||
# compute input sequence mask
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
|
||||
1).to(x.dtype)
|
||||
# pre-conv layers
|
||||
if hasattr(self, 'pre') and self.use_prenet:
|
||||
x = self.pre(x, x_mask)
|
||||
# prenet
|
||||
if hasattr(self, 'prenet') and self.use_prenet:
|
||||
x = self.prenet(x, x_mask)
|
||||
# encoder
|
||||
x = self.encoder(x, x_mask)
|
||||
# postnet
|
||||
if hasattr(self, 'postnet'):
|
||||
x = self.postnet(x) * x_mask
|
||||
# set duration predictor input
|
||||
if g is not None:
|
||||
g_exp = g.expand(-1, -1, x.size(-1))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from TTS.tts.layers.generic.wavenet import WN
|
||||
|
||||
from ..generic.normalization import LayerNorm
|
||||
|
||||
|
@ -50,104 +51,6 @@ class ConvLayerNorm(nn.Module):
|
|||
return x * x_mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
num_layers,
|
||||
c_in_channels=0,
|
||||
dropout_p=0):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
assert hidden_channels % 2 == 0
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_layers = num_layers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.dropout_p = dropout_p
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.dropout = nn.Dropout(dropout_p)
|
||||
|
||||
if c_in_channels > 0:
|
||||
cond_layer = torch.nn.Conv1d(c_in_channels,
|
||||
2 * hidden_channels * num_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
|
||||
name='weight')
|
||||
|
||||
for i in range(num_layers):
|
||||
dilation = dilation_rate**i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(hidden_channels,
|
||||
2 * hidden_channels,
|
||||
kernel_size,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
if i < num_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels,
|
||||
res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
|
||||
name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
x_in = self.dropout(x_in)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:,
|
||||
cond_offset:cond_offset + 2 * self.hidden_channels, :]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
|
||||
n_channels_tensor)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.num_layers - 1:
|
||||
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
|
||||
output = output + res_skip_acts[:, self.hidden_channels:, :]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
|
||||
def remove_weight_norm(self):
|
||||
if self.c_in_channels != 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
class InvConvNear(nn.Module):
|
||||
def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument
|
||||
super().__init__()
|
||||
|
@ -166,7 +69,7 @@ class InvConvNear(nn.Module):
|
|||
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||
"""Split the input into groups of size self.num_splits and
|
||||
perform 1x1 convolution separately. Cast 1x1 conv operation
|
||||
to 2d by reshaping the input for efficienty.
|
||||
to 2d by reshaping the input for efficiency.
|
||||
"""
|
||||
|
||||
b, c, t = x.size()
|
||||
|
|
|
@ -81,7 +81,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
# compute raw attention scores
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
|
||||
self.k_channels)
|
||||
# relative positional encoding
|
||||
# relative positional encoding for scores
|
||||
if self.rel_attn_window_size is not None:
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
# get relative key embeddings
|
||||
|
@ -262,7 +262,7 @@ class FFN(nn.Module):
|
|||
return x * x_mask
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class RelativePositionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
hidden_channels,
|
||||
hidden_channels_ffn,
|
||||
|
|
|
@ -10,44 +10,59 @@ from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
|
|||
|
||||
|
||||
class GlowTts(nn.Module):
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129"""
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
||||
|
||||
Args:
|
||||
num_chars (int): number of embedding characters.
|
||||
hidden_channels_enc (int): number of embedding and encoder channels.
|
||||
hidden_channels_dec (int): number of decoder channels.
|
||||
use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder.
|
||||
hidden_channels_dp (int): number of duration predictor channels.
|
||||
out_channels (int): number of output channels. It should be equal to the number of spectrogram filter.
|
||||
num_flow_blocks_dec (int): number of decoder blocks.
|
||||
kernel_size_dec (int): decoder kernel size.
|
||||
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
|
||||
num_block_layers (int): number of decoder layers in each decoder block.
|
||||
dropout_p_dec (float): dropout rate for decoder.
|
||||
num_speaker (int): number of speaker to define the size of speaker embedding layer.
|
||||
c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned.
|
||||
num_splits (int): number of split levels in inversible conv1x1 operation.
|
||||
num_squeeze (int): number of squeeze levels. When squeezing channels increases and time steps reduces by the factor 'num_squeeze'.
|
||||
sigmoid_scale (bool): enable/disable sigmoid scaling in decoder.
|
||||
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
|
||||
encoder_type (str): encoder module type.
|
||||
encoder_params (dict): encoder module parameters.
|
||||
external_speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
hidden_channels,
|
||||
hidden_channels_ffn,
|
||||
hidden_channels_enc,
|
||||
hidden_channels_dec,
|
||||
use_encoder_prenet,
|
||||
hidden_channels_dp,
|
||||
out_channels,
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
dropout_p=0.1,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=5,
|
||||
num_block_layers=4,
|
||||
dropout_p_dec=0.,
|
||||
dropout_p_dp=0.1,
|
||||
dropout_p_dec=0.05,
|
||||
num_speakers=0,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_sqz=1,
|
||||
num_squeeze=1,
|
||||
sigmoid_scale=False,
|
||||
rel_attn_window_size=None,
|
||||
input_length=None,
|
||||
mean_only=False,
|
||||
hidden_channels_enc=None,
|
||||
hidden_channels_dec=None,
|
||||
use_encoder_prenet=False,
|
||||
encoder_type="transformer",
|
||||
encoder_params=None,
|
||||
external_speaker_embedding_dim=None):
|
||||
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.hidden_channels = hidden_channels
|
||||
self.hidden_channels_ffn = hidden_channels_ffn
|
||||
self.hidden_channels_dp = hidden_channels_dp
|
||||
self.hidden_channels_enc = hidden_channels_enc
|
||||
self.hidden_channels_dec = hidden_channels_dec
|
||||
self.out_channels = out_channels
|
||||
self.num_heads = num_heads
|
||||
self.num_layers_enc = num_layers_enc
|
||||
self.dropout_p = dropout_p
|
||||
self.num_flow_blocks_dec = num_flow_blocks_dec
|
||||
self.kernel_size_dec = kernel_size_dec
|
||||
self.dilation_rate = dilation_rate
|
||||
|
@ -56,16 +71,14 @@ class GlowTts(nn.Module):
|
|||
self.num_speakers = num_speakers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.num_splits = num_splits
|
||||
self.num_sqz = num_sqz
|
||||
self.num_squeeze = num_squeeze
|
||||
self.sigmoid_scale = sigmoid_scale
|
||||
self.rel_attn_window_size = rel_attn_window_size
|
||||
self.input_length = input_length
|
||||
self.mean_only = mean_only
|
||||
self.hidden_channels_enc = hidden_channels_enc
|
||||
self.hidden_channels_dec = hidden_channels_dec
|
||||
self.use_encoder_prenet = use_encoder_prenet
|
||||
self.noise_scale = 0.66
|
||||
self.length_scale = 1.
|
||||
|
||||
# model constants.
|
||||
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
|
||||
self.length_scale = 1. # scaler for the duration predictor. The larger it is, the slower the speech.
|
||||
self.external_speaker_embedding_dim = external_speaker_embedding_dim
|
||||
|
||||
# if is a multispeaker and c_in_channels is 0, set to 256
|
||||
|
@ -77,27 +90,24 @@ class GlowTts(nn.Module):
|
|||
|
||||
self.encoder = Encoder(num_chars,
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
hidden_channels_ffn=hidden_channels_ffn,
|
||||
hidden_channels=hidden_channels_enc,
|
||||
hidden_channels_dp=hidden_channels_dp,
|
||||
encoder_type=encoder_type,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers_enc,
|
||||
dropout_p=dropout_p,
|
||||
rel_attn_window_size=rel_attn_window_size,
|
||||
encoder_params=encoder_params,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
dropout_p_dp=dropout_p_dp,
|
||||
c_in_channels=self.c_in_channels)
|
||||
|
||||
self.decoder = Decoder(out_channels,
|
||||
hidden_channels_dec or hidden_channels,
|
||||
hidden_channels_dec,
|
||||
kernel_size_dec,
|
||||
dilation_rate,
|
||||
num_flow_blocks_dec,
|
||||
num_block_layers,
|
||||
dropout_p=dropout_p_dec,
|
||||
num_splits=num_splits,
|
||||
num_sqz=num_sqz,
|
||||
num_squeeze=num_squeeze,
|
||||
sigmoid_scale=sigmoid_scale,
|
||||
c_in_channels=self.c_in_channels)
|
||||
|
||||
|
@ -140,7 +150,7 @@ class GlowTts(nn.Module):
|
|||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||
x_lengths,
|
||||
g=g)
|
||||
# drop redisual frames wrt num_sqz and set y_lengths.
|
||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||
y, y_lengths, y_max_length, attn = self.preprocess(
|
||||
y, y_lengths, y_max_length, None)
|
||||
# create masks
|
||||
|
@ -195,6 +205,7 @@ class GlowTts(nn.Module):
|
|||
attn_mask.squeeze(1)).unsqueeze(1)
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||
attn, o_mean, o_log_scale, x_mask)
|
||||
|
||||
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
|
||||
self.noise_scale) * y_mask
|
||||
# decoder pass
|
||||
|
@ -204,11 +215,11 @@ class GlowTts(nn.Module):
|
|||
|
||||
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
||||
if y_max_length is not None:
|
||||
y_max_length = (y_max_length // self.num_sqz) * self.num_sqz
|
||||
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
|
||||
y = y[:, :, :y_max_length]
|
||||
if attn is not None:
|
||||
attn = attn[:, :, :, :y_max_length]
|
||||
y_lengths = (y_lengths // self.num_sqz) * self.num_sqz
|
||||
y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
|
||||
return y, y_lengths, y_max_length, attn
|
||||
|
||||
def store_inverse(self):
|
||||
|
|
|
@ -103,15 +103,12 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
speaker_embedding_dim=speaker_embedding_dim)
|
||||
elif c.model.lower() == "glow_tts":
|
||||
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
hidden_channels=192,
|
||||
hidden_channels_ffn=768,
|
||||
hidden_channels_dp=256,
|
||||
hidden_channels_enc=c['hidden_channels_encoder'],
|
||||
hidden_channels_dec=c['hidden_channels_decoder'],
|
||||
hidden_channels_dp=c['hidden_channels_duration_predictor'],
|
||||
out_channels=c.audio['num_mels'],
|
||||
num_heads=2,
|
||||
num_layers_enc=6,
|
||||
encoder_type=c.encoder_type,
|
||||
rel_attn_window_size=4,
|
||||
dropout_p=0.1,
|
||||
encoder_params=c.encoder_params,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=1,
|
||||
|
|
Loading…
Reference in New Issue