From e82d31b6ac03f20119495559ea09a5a4ed9d7f7e Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 5 Jan 2021 14:29:45 +0100 Subject: [PATCH] glow ttss refactoring --- TTS/tts/layers/glow_tts/encoder.py | 55 +++++--------- TTS/tts/layers/glow_tts/glow.py | 101 +------------------------ TTS/tts/layers/glow_tts/transformer.py | 4 +- TTS/tts/models/glow_tts.py | 83 +++++++++++--------- TTS/tts/utils/generic_utils.py | 11 +-- 5 files changed, 75 insertions(+), 179 deletions(-) diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index f4f80ee5..6a5c2fad 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -94,49 +94,32 @@ class Encoder(nn.Module): # embedding layer self.emb = nn.Embedding(num_chars, hidden_channels) nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - # init encoder - if encoder_type.lower() == "transformer": - # optional convolutional prenet + # init encoder module + if encoder_type.lower() == "rel_pos_transformer": if use_prenet: - self.pre = ConvLayerNorm(hidden_channels, + self.prenet = ConvLayerNorm(hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5) - # text encoder - self.encoder = Transformer( - hidden_channels, - hidden_channels_ffn, - num_heads, - num_layers, - kernel_size=3, - dropout_p=dropout_p, - rel_attn_window_size=rel_attn_window_size, - input_length=input_length) + self.encoder = RelativePositionTransformer( + hidden_channels, **encoder_params) elif encoder_type.lower() == 'gated_conv': - self.encoder = GatedConvBlock(hidden_channels, - kernel_size=5, - dropout_p=dropout_p, - num_layers=3 + num_layers) + self.encoder = GatedConvBlock(hidden_channels, **encoder_params) elif encoder_type.lower() == 'residual_conv_bn': if use_prenet: - self.pre = nn.Sequential( + self.prenet = nn.Sequential( nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU() ) - dilations = 4 * [1, 2, 4] + [1] - num_conv_blocks = 2 - 
num_res_blocks = 13 # total 2 * 13 blocks - self.encoder = ResidualConvBNBlock(hidden_channels, - kernel_size=4, - dilations=dilations, - num_res_blocks=num_res_blocks, - num_conv_blocks=num_conv_blocks) + self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params) + self.postnet = nn.Sequential( + nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), + nn.BatchNorm1d(self.hidden_channels)) elif encoder_type.lower() == 'time_depth_separable': - # optional convolutional prenet if use_prenet: - self.pre = ConvLayerNorm(hidden_channels, + self.prenet = ConvLayerNorm(hidden_channels, hidden_channels, hidden_channels, kernel_size=5, @@ -145,8 +128,7 @@ class Encoder(nn.Module): self.encoder = TimeDepthSeparableConvBlock(hidden_channels, hidden_channels, hidden_channels, - kernel_size=5, - num_layers=3 + num_layers) + **encoder_params) else: raise ValueError(" [!] Unkown encoder type.") @@ -157,7 +139,7 @@ class Encoder(nn.Module): # duration predictor self.duration_predictor = DurationPredictor( hidden_channels + c_in_channels, hidden_channels_dp, 3, - dropout_p) + dropout_p_dp) def forward(self, x, x_lengths, g=None): # embedding layer @@ -168,11 +150,14 @@ class Encoder(nn.Module): # compute input sequence mask x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - # pre-conv layers - if hasattr(self, 'pre') and self.use_prenet: - x = self.pre(x, x_mask) + # prenet + if hasattr(self, 'prenet') and self.use_prenet: + x = self.prenet(x, x_mask) # encoder x = self.encoder(x, x_mask) + # postnet + if hasattr(self, 'postnet'): + x = self.postnet(x) * x_mask # set duration predictor input if g is not None: g_exp = g.expand(-1, -1, x.size(-1)) diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index eba593dc..2a2ff189 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -1,6 +1,7 @@ import torch from torch import nn from torch.nn import functional as F +from 
TTS.tts.layers.generic.wavenet import WN from ..generic.normalization import LayerNorm @@ -50,104 +51,6 @@ class ConvLayerNorm(nn.Module): return x * x_mask -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -class WN(torch.nn.Module): - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_layers, - c_in_channels=0, - dropout_p=0): - super().__init__() - assert kernel_size % 2 == 1 - assert hidden_channels % 2 == 0 - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.num_layers = num_layers - self.c_in_channels = c_in_channels - self.dropout_p = dropout_p - - self.in_layers = torch.nn.ModuleList() - self.res_skip_layers = torch.nn.ModuleList() - self.dropout = nn.Dropout(dropout_p) - - if c_in_channels > 0: - cond_layer = torch.nn.Conv1d(c_in_channels, - 2 * hidden_channels * num_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, - name='weight') - - for i in range(num_layers): - dilation = dilation_rate**i - padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, - 2 * hidden_channels, - kernel_size, - dilation=dilation, - padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') - self.in_layers.append(in_layer) - - if i < num_layers - 1: - res_skip_channels = 2 * hidden_channels - else: - res_skip_channels = hidden_channels - - res_skip_layer = torch.nn.Conv1d(hidden_channels, - res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, - name='weight') - self.res_skip_layers.append(res_skip_layer) - - def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: 
disable=unused-argument - output = torch.zeros_like(x) - n_channels_tensor = torch.IntTensor([self.hidden_channels]) - - if g is not None: - g = self.cond_layer(g) - - for i in range(self.num_layers): - x_in = self.in_layers[i](x) - x_in = self.dropout(x_in) - if g is not None: - cond_offset = i * 2 * self.hidden_channels - g_l = g[:, - cond_offset:cond_offset + 2 * self.hidden_channels, :] - else: - g_l = torch.zeros_like(x_in) - - acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, - n_channels_tensor) - - res_skip_acts = self.res_skip_layers[i](acts) - if i < self.num_layers - 1: - x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask - output = output + res_skip_acts[:, self.hidden_channels:, :] - else: - output = output + res_skip_acts - return output * x_mask - - def remove_weight_norm(self): - if self.c_in_channels != 0: - torch.nn.utils.remove_weight_norm(self.cond_layer) - for l in self.in_layers: - torch.nn.utils.remove_weight_norm(l) - for l in self.res_skip_layers: - torch.nn.utils.remove_weight_norm(l) - class InvConvNear(nn.Module): def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument super().__init__() @@ -166,7 +69,7 @@ class InvConvNear(nn.Module): def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument """Split the input into groups of size self.num_splits and perform 1x1 convolution separately. Cast 1x1 conv operation - to 2d by reshaping the input for efficienty. + to 2d by reshaping the input for efficiency. 
""" b, c, t = x.size() diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index cc61c760..291c845a 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -81,7 +81,7 @@ class RelativePositionMultiHeadAttention(nn.Module): # compute raw attention scores scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( self.k_channels) - # relative positional encoding + # relative positional encoding for scores if self.rel_attn_window_size is not None: assert t_s == t_t, "Relative attention is only available for self-attention." # get relative key embeddings @@ -262,7 +262,7 @@ class FFN(nn.Module): return x * x_mask -class Transformer(nn.Module): +class RelativePositionTransformer(nn.Module): def __init__(self, hidden_channels, hidden_channels_ffn, diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 9f157ded..af5979dd 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -10,44 +10,59 @@ from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path class GlowTts(nn.Module): - """Glow TTS models from https://arxiv.org/abs/2005.11129""" + """Glow TTS models from https://arxiv.org/abs/2005.11129 + + Args: + num_chars (int): number of embedding characters. + hidden_channels_enc (int): number of embedding and encoder channels. + hidden_channels_dec (int): number of decoder channels. + use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder. + hidden_channels_dp (int): number of duration predictor channels. + out_channels (int): number of output channels. It should be equal to the number of spectrogram filter. + num_flow_blocks_dec (int): number of decoder blocks. + kernel_size_dec (int): decoder kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. 
+ num_block_layers (int): number of decoder layers in each decoder block. + dropout_p_dec (float): dropout rate for decoder. + num_speakers (int): number of speakers to define the size of speaker embedding layer. + c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned. + num_splits (int): number of split levels in invertible conv1x1 operation. + num_squeeze (int): number of squeeze levels. When squeezing, the number of channels increases and the number of time steps reduces by the factor 'num_squeeze'. + sigmoid_scale (bool): enable/disable sigmoid scaling in decoder. + mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step. + encoder_type (str): encoder module type. + encoder_params (dict): encoder module parameters. + external_speaker_embedding_dim (int): channels of external speaker embedding vectors. + """ def __init__(self, num_chars, - hidden_channels, - hidden_channels_ffn, + hidden_channels_enc, + hidden_channels_dec, + use_encoder_prenet, hidden_channels_dp, out_channels, - num_heads=2, - num_layers_enc=6, - dropout_p=0.1, num_flow_blocks_dec=12, kernel_size_dec=5, dilation_rate=5, num_block_layers=4, - dropout_p_dec=0., + dropout_p_dp=0.1, + dropout_p_dec=0.05, num_speakers=0, c_in_channels=0, num_splits=4, - num_sqz=1, + num_squeeze=1, sigmoid_scale=False, - rel_attn_window_size=None, - input_length=None, mean_only=False, - hidden_channels_enc=None, - hidden_channels_dec=None, - use_encoder_prenet=False, encoder_type="transformer", + encoder_params=None, external_speaker_embedding_dim=None): super().__init__() self.num_chars = num_chars - self.hidden_channels = hidden_channels - self.hidden_channels_ffn = hidden_channels_ffn self.hidden_channels_dp = hidden_channels_dp + self.hidden_channels_enc = hidden_channels_enc + self.hidden_channels_dec = hidden_channels_dec self.out_channels = out_channels - self.num_heads = num_heads - self.num_layers_enc = num_layers_enc - self.dropout_p = 
dropout_p self.num_flow_blocks_dec = num_flow_blocks_dec self.kernel_size_dec = kernel_size_dec self.dilation_rate = dilation_rate @@ -56,16 +71,14 @@ class GlowTts(nn.Module): self.num_speakers = num_speakers self.c_in_channels = c_in_channels self.num_splits = num_splits - self.num_sqz = num_sqz + self.num_squeeze = num_squeeze self.sigmoid_scale = sigmoid_scale - self.rel_attn_window_size = rel_attn_window_size - self.input_length = input_length self.mean_only = mean_only - self.hidden_channels_enc = hidden_channels_enc - self.hidden_channels_dec = hidden_channels_dec self.use_encoder_prenet = use_encoder_prenet - self.noise_scale = 0.66 - self.length_scale = 1. + + # model constants. + self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference. + self.length_scale = 1. # scaler for the duration predictor. The larger it is, the slower the speech. self.external_speaker_embedding_dim = external_speaker_embedding_dim # if is a multispeaker and c_in_channels is 0, set to 256 @@ -77,27 +90,24 @@ class GlowTts(nn.Module): self.encoder = Encoder(num_chars, out_channels=out_channels, - hidden_channels=hidden_channels, - hidden_channels_ffn=hidden_channels_ffn, + hidden_channels=hidden_channels_enc, hidden_channels_dp=hidden_channels_dp, encoder_type=encoder_type, - num_heads=num_heads, - num_layers=num_layers_enc, - dropout_p=dropout_p, - rel_attn_window_size=rel_attn_window_size, + encoder_params=encoder_params, mean_only=mean_only, use_prenet=use_encoder_prenet, + dropout_p_dp=dropout_p_dp, c_in_channels=self.c_in_channels) self.decoder = Decoder(out_channels, - hidden_channels_dec or hidden_channels, + hidden_channels_dec, kernel_size_dec, dilation_rate, num_flow_blocks_dec, num_block_layers, dropout_p=dropout_p_dec, num_splits=num_splits, - num_sqz=num_sqz, + num_squeeze=num_squeeze, sigmoid_scale=sigmoid_scale, c_in_channels=self.c_in_channels) @@ -140,7 +150,7 @@ class GlowTts(nn.Module): o_mean, o_log_scale, o_dur_log, x_mask 
= self.encoder(x, x_lengths, g=g) - # drop redisual frames wrt num_sqz and set y_lengths. + # drop residual frames wrt num_squeeze and set y_lengths. y, y_lengths, y_max_length, attn = self.preprocess( y, y_lengths, y_max_length, None) # create masks @@ -195,6 +205,7 @@ class GlowTts(nn.Module): attn_mask.squeeze(1)).unsqueeze(1) y_mean, y_log_scale, o_attn_dur = self.compute_outputs( attn, o_mean, o_log_scale, x_mask) + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask # decoder pass @@ -204,11 +215,11 @@ class GlowTts(nn.Module): def preprocess(self, y, y_lengths, y_max_length, attn=None): if y_max_length is not None: - y_max_length = (y_max_length // self.num_sqz) * self.num_sqz + y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze y = y[:, :, :y_max_length] if attn is not None: attn = attn[:, :, :, :y_max_length] - y_lengths = (y_lengths // self.num_sqz) * self.num_sqz + y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze return y, y_lengths, y_max_length, attn def store_inverse(self): diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 568fc8c6..19ce7a16 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -103,15 +103,12 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "glow_tts": model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - hidden_channels=192, - hidden_channels_ffn=768, - hidden_channels_dp=256, + hidden_channels_enc=c['hidden_channels_encoder'], + hidden_channels_dec=c['hidden_channels_decoder'], + hidden_channels_dp=c['hidden_channels_duration_predictor'], out_channels=c.audio['num_mels'], - num_heads=2, - num_layers_enc=6, encoder_type=c.encoder_type, - rel_attn_window_size=4, - dropout_p=0.1, + encoder_params=c.encoder_params, num_flow_blocks_dec=12, kernel_size_dec=5, dilation_rate=1,