glow tts refactoring

erogol 2021-01-05 14:29:45 +01:00
parent 29f4329d7f
commit e82d31b6ac
5 changed files with 75 additions and 179 deletions

View File

@@ -94,49 +94,32 @@ class Encoder(nn.Module):
# embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
- # init encoder
- if encoder_type.lower() == "transformer":
- # optional convolutional prenet
+ # init encoder module
+ if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:
- self.pre = ConvLayerNorm(hidden_channels,
+ self.prenet = ConvLayerNorm(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
- # text encoder
- self.encoder = Transformer(
- hidden_channels,
- hidden_channels_ffn,
- num_heads,
- num_layers,
- kernel_size=3,
- dropout_p=dropout_p,
- rel_attn_window_size=rel_attn_window_size,
- input_length=input_length)
+ self.encoder = RelativePositionTransformer(
+ hidden_channels, **encoder_params)
elif encoder_type.lower() == 'gated_conv':
- self.encoder = GatedConvBlock(hidden_channels,
- kernel_size=5,
- dropout_p=dropout_p,
- num_layers=3 + num_layers)
+ self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == 'residual_conv_bn':
if use_prenet:
- self.pre = nn.Sequential(
+ self.prenet = nn.Sequential(
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU()
)
- dilations = 4 * [1, 2, 4] + [1]
- num_conv_blocks = 2
- num_res_blocks = 13 # total 2 * 13 blocks
- self.encoder = ResidualConvBNBlock(hidden_channels,
- kernel_size=4,
- dilations=dilations,
- num_res_blocks=num_res_blocks,
- num_conv_blocks=num_conv_blocks)
+ self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
+ self.postnet = nn.Sequential(
+ nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
+ nn.BatchNorm1d(self.hidden_channels))
elif encoder_type.lower() == 'time_depth_separable':
# optional convolutional prenet
if use_prenet:
- self.pre = ConvLayerNorm(hidden_channels,
+ self.prenet = ConvLayerNorm(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
@@ -145,8 +128,7 @@ class Encoder(nn.Module):
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
hidden_channels,
hidden_channels,
- kernel_size=5,
- num_layers=3 + num_layers)
+ **encoder_params)
else:
raise ValueError(" [!] Unknown encoder type.")
@@ -157,7 +139,7 @@ class Encoder(nn.Module):
# duration predictor
self.duration_predictor = DurationPredictor(
hidden_channels + c_in_channels, hidden_channels_dp, 3,
- dropout_p)
+ dropout_p_dp)
def forward(self, x, x_lengths, g=None):
# embedding layer
@@ -168,11 +150,14 @@ class Encoder(nn.Module):
# compute input sequence mask
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
1).to(x.dtype)
- # pre-conv layers
- if hasattr(self, 'pre') and self.use_prenet:
- x = self.pre(x, x_mask)
+ # prenet
+ if hasattr(self, 'prenet') and self.use_prenet:
+ x = self.prenet(x, x_mask)
# encoder
x = self.encoder(x, x_mask)
+ # postnet
+ if hasattr(self, 'postnet'):
+ x = self.postnet(x) * x_mask
# set duration predictor input
if g is not None:
g_exp = g.expand(-1, -1, x.size(-1))
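
The hunks above collapse the per-encoder keyword arguments into a single encoder_params dict that is unpacked into whichever encoder module is selected. A minimal sketch of plausible encoder_params values, using the keyword names visible in the removed lines; the exact values and the grouping into one lookup table are illustrative assumptions, not the repo's defaults:

# Illustrative encoder_params per encoder_type. The keys mirror the
# keyword arguments removed in this diff; the values are assumptions.
example_encoder_params = {
    "rel_pos_transformer": {
        "hidden_channels_ffn": 768,
        "num_heads": 2,
        "num_layers": 6,
        "kernel_size": 3,
        "dropout_p": 0.1,
        "rel_attn_window_size": 4,
        "input_length": None,
    },
    "gated_conv": {"kernel_size": 5, "dropout_p": 0.1, "num_layers": 9},
    "residual_conv_bn": {
        "kernel_size": 4,
        "dilations": 4 * [1, 2, 4] + [1],
        "num_conv_blocks": 2,
        "num_res_blocks": 13,
    },
    "time_depth_separable": {"kernel_size": 5, "num_layers": 9},
}

With this shape, Encoder no longer needs to know each encoder's signature; adding a new encoder type only requires a new branch and a matching params dict.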

View File

@@ -1,6 +1,7 @@
import torch
from torch import nn
from torch.nn import functional as F
+ from TTS.tts.layers.generic.wavenet import WN
from ..generic.normalization import LayerNorm
@@ -50,104 +51,6 @@ class ConvLayerNorm(nn.Module):
return x * x_mask
- @torch.jit.script
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
- n_channels_int = n_channels[0]
- in_act = input_a + input_b
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
- acts = t_act * s_act
- return acts
- class WN(torch.nn.Module):
- def __init__(self,
- in_channels,
- hidden_channels,
- kernel_size,
- dilation_rate,
- num_layers,
- c_in_channels=0,
- dropout_p=0):
- super().__init__()
- assert kernel_size % 2 == 1
- assert hidden_channels % 2 == 0
- self.in_channels = in_channels
- self.hidden_channels = hidden_channels
- self.kernel_size = kernel_size
- self.dilation_rate = dilation_rate
- self.num_layers = num_layers
- self.c_in_channels = c_in_channels
- self.dropout_p = dropout_p
- self.in_layers = torch.nn.ModuleList()
- self.res_skip_layers = torch.nn.ModuleList()
- self.dropout = nn.Dropout(dropout_p)
- if c_in_channels > 0:
- cond_layer = torch.nn.Conv1d(c_in_channels,
- 2 * hidden_channels * num_layers, 1)
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
- name='weight')
- for i in range(num_layers):
- dilation = dilation_rate**i
- padding = int((kernel_size * dilation - dilation) / 2)
- in_layer = torch.nn.Conv1d(hidden_channels,
- 2 * hidden_channels,
- kernel_size,
- dilation=dilation,
- padding=padding)
- in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
- self.in_layers.append(in_layer)
- if i < num_layers - 1:
- res_skip_channels = 2 * hidden_channels
- else:
- res_skip_channels = hidden_channels
- res_skip_layer = torch.nn.Conv1d(hidden_channels,
- res_skip_channels, 1)
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
- name='weight')
- self.res_skip_layers.append(res_skip_layer)
- def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument
- output = torch.zeros_like(x)
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
- if g is not None:
- g = self.cond_layer(g)
- for i in range(self.num_layers):
- x_in = self.in_layers[i](x)
- x_in = self.dropout(x_in)
- if g is not None:
- cond_offset = i * 2 * self.hidden_channels
- g_l = g[:,
- cond_offset:cond_offset + 2 * self.hidden_channels, :]
- else:
- g_l = torch.zeros_like(x_in)
- acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
- n_channels_tensor)
- res_skip_acts = self.res_skip_layers[i](acts)
- if i < self.num_layers - 1:
- x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
- output = output + res_skip_acts[:, self.hidden_channels:, :]
- else:
- output = output + res_skip_acts
- return output * x_mask
- def remove_weight_norm(self):
- if self.c_in_channels != 0:
- torch.nn.utils.remove_weight_norm(self.cond_layer)
- for l in self.in_layers:
- torch.nn.utils.remove_weight_norm(l)
- for l in self.res_skip_layers:
- torch.nn.utils.remove_weight_norm(l)
class InvConvNear(nn.Module):
def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument
super().__init__()
@@ -166,7 +69,7 @@ class InvConvNear(nn.Module):
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
"""Split the input into groups of size self.num_splits and
perform 1x1 convolution separately. Cast 1x1 conv operation
- to 2d by reshaping the input for efficienty.
+ to 2d by reshaping the input for efficiency.
"""
b, c, t = x.size()
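
The fused_add_tanh_sigmoid_multiply helper deleted above (the WaveNet gated activation, with WN now imported from TTS.tts.layers.generic.wavenet) splits the summed inputs into a tanh half and a sigmoid half along the channel axis and multiplies the two gates. A self-contained sketch of the unfused equivalent; the shapes are illustrative:

import torch

# Unfused reference for the removed gated activation: the first
# n_channels of the summed inputs go through tanh, the rest through
# sigmoid, and the two gates are multiplied elementwise.
def gated_activation(input_a, input_b, n_channels):
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels, :])
    s_act = torch.sigmoid(in_act[:, n_channels:, :])
    return t_act * s_act

# Illustrative shapes: batch=2, 2*hidden=8 channels, 50 time steps.
a = torch.randn(2, 8, 50)
b = torch.randn(2, 8, 50)
out = gated_activation(a, b, n_channels=4)
print(out.shape)  # torch.Size([2, 4, 50])

The @torch.jit.script version passes n_channels as a tensor so the op can be fused; the logic is otherwise identical.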

View File

@@ -81,7 +81,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
# compute raw attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
self.k_channels)
- # relative positional encoding
+ # relative positional encoding for scores
if self.rel_attn_window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
# get relative key embeddings
@@ -262,7 +262,7 @@ class FFN(nn.Module):
return x * x_mask
- class Transformer(nn.Module):
+ class RelativePositionTransformer(nn.Module):
def __init__(self,
hidden_channels,
hidden_channels_ffn,
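
The renamed RelativePositionTransformer keeps the score computation shown in the first hunk: plain scaled dot-product attention, with a relative-position term added when rel_attn_window_size is set (self-attention only, hence the t_s == t_t assert). A minimal sketch of the base computation; the shapes are illustrative assumptions and the relative term is only indicated in a comment:

import math
import torch

# Illustrative shapes: batch=2, heads=2, 10 time steps, 32 channels per head.
b, h, t, k_channels = 2, 2, 10, 32
query = torch.randn(b, h, t, k_channels)
key = torch.randn(b, h, t, k_channels)

# Raw attention scores, as in the diff above.
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(k_channels)
# With relative attention enabled, a term derived from learned relative
# key embeddings within rel_attn_window_size would be added to `scores`.
print(scores.shape)  # torch.Size([2, 2, 10, 10])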

View File

@@ -10,44 +10,59 @@ from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
class GlowTts(nn.Module):
"""Glow TTS models from https://arxiv.org/abs/2005.11129"""
"""Glow TTS models from https://arxiv.org/abs/2005.11129
Args:
num_chars (int): number of embedding characters.
hidden_channels_enc (int): number of embedding and encoder channels.
hidden_channels_dec (int): number of decoder channels.
use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder.
hidden_channels_dp (int): number of duration predictor channels.
out_channels (int): number of output channels. It should be equal to the number of spectrogram filter.
num_flow_blocks_dec (int): number of decoder blocks.
kernel_size_dec (int): decoder kernel size.
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
num_block_layers (int): number of decoder layers in each decoder block.
dropout_p_dec (float): dropout rate for decoder.
num_speaker (int): number of speaker to define the size of speaker embedding layer.
c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned.
num_splits (int): number of split levels in inversible conv1x1 operation.
num_squeeze (int): number of squeeze levels. When squeezing channels increases and time steps reduces by the factor 'num_squeeze'.
sigmoid_scale (bool): enable/disable sigmoid scaling in decoder.
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
encoder_type (str): encoder module type.
encoder_params (dict): encoder module parameters.
external_speaker_embedding_dim (int): channels of external speaker embedding vectors.
"""
def __init__(self,
num_chars,
- hidden_channels,
- hidden_channels_ffn,
+ hidden_channels_enc,
+ hidden_channels_dec,
+ use_encoder_prenet,
hidden_channels_dp,
out_channels,
- num_heads=2,
- num_layers_enc=6,
- dropout_p=0.1,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
num_block_layers=4,
- dropout_p_dec=0.,
+ dropout_p_dp=0.1,
+ dropout_p_dec=0.05,
num_speakers=0,
c_in_channels=0,
num_splits=4,
- num_sqz=1,
+ num_squeeze=1,
sigmoid_scale=False,
- rel_attn_window_size=None,
- input_length=None,
mean_only=False,
- hidden_channels_enc=None,
- hidden_channels_dec=None,
- use_encoder_prenet=False,
encoder_type="transformer",
+ encoder_params=None,
external_speaker_embedding_dim=None):
super().__init__()
self.num_chars = num_chars
- self.hidden_channels = hidden_channels
- self.hidden_channels_ffn = hidden_channels_ffn
self.hidden_channels_dp = hidden_channels_dp
+ self.hidden_channels_enc = hidden_channels_enc
+ self.hidden_channels_dec = hidden_channels_dec
self.out_channels = out_channels
- self.num_heads = num_heads
- self.num_layers_enc = num_layers_enc
- self.dropout_p = dropout_p
self.num_flow_blocks_dec = num_flow_blocks_dec
self.kernel_size_dec = kernel_size_dec
self.dilation_rate = dilation_rate
@@ -56,16 +71,14 @@ class GlowTts(nn.Module):
self.num_speakers = num_speakers
self.c_in_channels = c_in_channels
self.num_splits = num_splits
- self.num_sqz = num_sqz
+ self.num_squeeze = num_squeeze
self.sigmoid_scale = sigmoid_scale
- self.rel_attn_window_size = rel_attn_window_size
- self.input_length = input_length
self.mean_only = mean_only
- self.hidden_channels_enc = hidden_channels_enc
- self.hidden_channels_dec = hidden_channels_dec
self.use_encoder_prenet = use_encoder_prenet
- self.noise_scale = 0.66
- self.length_scale = 1.
+ # model constants.
+ self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
+ self.length_scale = 1. # scale factor for the predicted durations; the larger it is, the slower the speech.
self.external_speaker_embedding_dim = external_speaker_embedding_dim
# if it is a multi-speaker model and c_in_channels is 0, set it to 256
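
Since noise_scale and length_scale are plain attributes rather than constructor arguments, they can be tuned per synthesis call without retraining. A usage sketch, assuming model is a loaded GlowTts instance:

# Adjusting the model constants above at inference time;
# `model` is assumed to be a loaded GlowTts instance.
model.noise_scale = 0.5   # more variance in the sampled z -> more varied prosody
model.length_scale = 1.2  # scale up predicted durations -> slower speech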
@@ -77,27 +90,24 @@ class GlowTts(nn.Module):
self.encoder = Encoder(num_chars,
out_channels=out_channels,
- hidden_channels=hidden_channels,
- hidden_channels_ffn=hidden_channels_ffn,
+ hidden_channels=hidden_channels_enc,
hidden_channels_dp=hidden_channels_dp,
encoder_type=encoder_type,
- num_heads=num_heads,
- num_layers=num_layers_enc,
- dropout_p=dropout_p,
- rel_attn_window_size=rel_attn_window_size,
+ encoder_params=encoder_params,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
+ dropout_p_dp=dropout_p_dp,
c_in_channels=self.c_in_channels)
self.decoder = Decoder(out_channels,
- hidden_channels_dec or hidden_channels,
+ hidden_channels_dec,
kernel_size_dec,
dilation_rate,
num_flow_blocks_dec,
num_block_layers,
dropout_p=dropout_p_dec,
num_splits=num_splits,
- num_sqz=num_sqz,
+ num_squeeze=num_squeeze,
sigmoid_scale=sigmoid_scale,
c_in_channels=self.c_in_channels)
@@ -140,7 +150,7 @@ class GlowTts(nn.Module):
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
- # drop redisual frames wrt num_sqz and set y_lengths.
+ # drop residual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess(
y, y_lengths, y_max_length, None)
# create masks
@@ -195,6 +205,7 @@ class GlowTts(nn.Module):
attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
self.noise_scale) * y_mask
# decoder pass
@@ -204,11 +215,11 @@ class GlowTts(nn.Module):
def preprocess(self, y, y_lengths, y_max_length, attn=None):
if y_max_length is not None:
- y_max_length = (y_max_length // self.num_sqz) * self.num_sqz
+ y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
y = y[:, :, :y_max_length]
if attn is not None:
attn = attn[:, :, :, :y_max_length]
- y_lengths = (y_lengths // self.num_sqz) * self.num_sqz
+ y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
return y, y_lengths, y_max_length, attn
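
preprocess floors the target length down to a multiple of num_squeeze so the spectrogram frames can be regrouped along the channel axis without a remainder, dropping the residual frames mentioned in the comment above. A worked example of the flooring:

# Worked example of the flooring in preprocess(), assuming num_squeeze=2.
num_squeeze = 2
y_max_length = 101
y_max_length = (y_max_length // num_squeeze) * num_squeeze
print(y_max_length)  # 100 -- one residual frame is dropped so the
                     # remaining frames divide evenly into squeeze groups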
def store_inverse(self):

View File

@@ -103,15 +103,12 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "glow_tts":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
- hidden_channels=192,
- hidden_channels_ffn=768,
- hidden_channels_dp=256,
+ hidden_channels_enc=c['hidden_channels_encoder'],
+ hidden_channels_dec=c['hidden_channels_decoder'],
+ hidden_channels_dp=c['hidden_channels_duration_predictor'],
out_channels=c.audio['num_mels'],
- num_heads=2,
- num_layers_enc=6,
encoder_type=c.encoder_type,
- rel_attn_window_size=4,
- dropout_p=0.1,
+ encoder_params=c.encoder_params,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=1,
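
setup_model now pulls these sizes from the config instead of hard-coding them. A sketch of a matching config fragment: the key names come from this diff and the values mirror the previously hard-coded defaults removed above, but a plain dict stands in for the repo's config object (which supports both key and attribute access, e.g. c.encoder_params).

# Illustrative config fragment consumed by setup_model; values mirror the
# defaults removed in this diff and are assumptions, not the repo's config.
c = {
    "hidden_channels_encoder": 192,
    "hidden_channels_decoder": 192,
    "hidden_channels_duration_predictor": 256,
    "encoder_type": "rel_pos_transformer",
    "encoder_params": {
        "hidden_channels_ffn": 768,
        "num_heads": 2,
        "num_layers": 6,
        "kernel_size": 3,
        "dropout_p": 0.1,
        "rel_attn_window_size": 4,
        "input_length": None,
    },
}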