glow tts refactoring

erogol 2021-01-05 14:29:45 +01:00
parent 29f4329d7f
commit e82d31b6ac
5 changed files with 75 additions and 179 deletions

View File

@@ -94,49 +94,32 @@ class Encoder(nn.Module):
         # embedding layer
         self.emb = nn.Embedding(num_chars, hidden_channels)
         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
-        # init encoder
-        if encoder_type.lower() == "transformer":
+        # init encoder module
+        if encoder_type.lower() == "rel_pos_transformer":
+            # optional convolutional prenet
             if use_prenet:
-                self.pre = ConvLayerNorm(hidden_channels,
+                self.prenet = ConvLayerNorm(hidden_channels,
                                          hidden_channels,
                                          hidden_channels,
                                          kernel_size=5,
                                          num_layers=3,
                                          dropout_p=0.5)
-            # text encoder
-            self.encoder = Transformer(
-                hidden_channels,
-                hidden_channels_ffn,
-                num_heads,
-                num_layers,
-                kernel_size=3,
-                dropout_p=dropout_p,
-                rel_attn_window_size=rel_attn_window_size,
-                input_length=input_length)
+            self.encoder = RelativePositionTransformer(
+                hidden_channels, **encoder_params)
         elif encoder_type.lower() == 'gated_conv':
-            self.encoder = GatedConvBlock(hidden_channels,
-                                          kernel_size=5,
-                                          dropout_p=dropout_p,
-                                          num_layers=3 + num_layers)
+            self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
         elif encoder_type.lower() == 'residual_conv_bn':
             if use_prenet:
-                self.pre = nn.Sequential(
+                self.prenet = nn.Sequential(
                     nn.Conv1d(hidden_channels, hidden_channels, 1),
                     nn.ReLU()
                 )
-            dilations = 4 * [1, 2, 4] + [1]
-            num_conv_blocks = 2
-            num_res_blocks = 13  # total 2 * 13 blocks
-            self.encoder = ResidualConvBNBlock(hidden_channels,
-                                               kernel_size=4,
-                                               dilations=dilations,
-                                               num_res_blocks=num_res_blocks,
-                                               num_conv_blocks=num_conv_blocks)
+            self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
+            self.postnet = nn.Sequential(
+                nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
+                nn.BatchNorm1d(self.hidden_channels))
         elif encoder_type.lower() == 'time_depth_separable':
+            # optional convolutional prenet
             if use_prenet:
-                self.pre = ConvLayerNorm(hidden_channels,
+                self.prenet = ConvLayerNorm(hidden_channels,
                                          hidden_channels,
                                          hidden_channels,
                                          kernel_size=5,
@@ -145,8 +128,7 @@ class Encoder(nn.Module):
             self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
                                                        hidden_channels,
                                                        hidden_channels,
-                                                       kernel_size=5,
-                                                       num_layers=3 + num_layers)
+                                                       **encoder_params)
         else:
             raise ValueError(" [!] Unkown encoder type.")
@@ -157,7 +139,7 @@ class Encoder(nn.Module):
         # duration predictor
         self.duration_predictor = DurationPredictor(
             hidden_channels + c_in_channels, hidden_channels_dp, 3,
-            dropout_p)
+            dropout_p_dp)

     def forward(self, x, x_lengths, g=None):
         # embedding layer
@@ -168,11 +150,14 @@ class Encoder(nn.Module):
         # compute input sequence mask
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
                                  1).to(x.dtype)
-        # pre-conv layers
-        if hasattr(self, 'pre') and self.use_prenet:
-            x = self.pre(x, x_mask)
+        # prenet
+        if hasattr(self, 'prenet') and self.use_prenet:
+            x = self.prenet(x, x_mask)
         # encoder
         x = self.encoder(x, x_mask)
+        # postnet
+        if hasattr(self, 'postnet'):
+            x = self.postnet(x) * x_mask
         # set duration predictor input
         if g is not None:
             g_exp = g.expand(-1, -1, x.size(-1))
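
Note on the refactored forward flow: prenet and postnet are now optional registered modules, and only the residual_conv_bn branch adds a postnet. A minimal, self-contained sketch of that control flow, using placeholder Conv1d modules rather than the repo's actual encoder blocks (all names below are illustrative):

import torch
from torch import nn

class EncoderSketch(nn.Module):
    """Placeholder modules only; mirrors the prenet -> encoder -> postnet flow."""
    def __init__(self, hidden_channels, use_prenet=True, use_postnet=False):
        super().__init__()
        self.use_prenet = use_prenet
        if use_prenet:
            self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 1)
        self.encoder = nn.Conv1d(hidden_channels, hidden_channels, 1)
        if use_postnet:  # mirrors the residual_conv_bn branch above
            self.postnet = nn.Sequential(
                nn.Conv1d(hidden_channels, hidden_channels, 1),
                nn.BatchNorm1d(hidden_channels))

    def forward(self, x, x_mask):
        if hasattr(self, 'prenet') and self.use_prenet:
            x = self.prenet(x) * x_mask
        x = self.encoder(x) * x_mask
        if hasattr(self, 'postnet'):
            x = self.postnet(x) * x_mask
        return x

x = torch.randn(2, 192, 33)
x_mask = torch.ones(2, 1, 33)
y = EncoderSketch(192, use_postnet=True)(x, x_mask)  # -> (2, 192, 33)

The hasattr checks let a single forward cover every encoder variant without per-branch forward logic.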

View File

@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from TTS.tts.layers.generic.wavenet import WN

 from ..generic.normalization import LayerNorm
@@ -50,104 +51,6 @@ class ConvLayerNorm(nn.Module):
         return x * x_mask


-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
-    in_act = input_a + input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
-    acts = t_act * s_act
-    return acts
-
-
-class WN(torch.nn.Module):
-    def __init__(self,
-                 in_channels,
-                 hidden_channels,
-                 kernel_size,
-                 dilation_rate,
-                 num_layers,
-                 c_in_channels=0,
-                 dropout_p=0):
-        super().__init__()
-        assert kernel_size % 2 == 1
-        assert hidden_channels % 2 == 0
-        self.in_channels = in_channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.num_layers = num_layers
-        self.c_in_channels = c_in_channels
-        self.dropout_p = dropout_p
-
-        self.in_layers = torch.nn.ModuleList()
-        self.res_skip_layers = torch.nn.ModuleList()
-        self.dropout = nn.Dropout(dropout_p)
-
-        if c_in_channels > 0:
-            cond_layer = torch.nn.Conv1d(c_in_channels,
-                                         2 * hidden_channels * num_layers, 1)
-            self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
-                                                         name='weight')
-
-        for i in range(num_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(hidden_channels,
-                                       2 * hidden_channels,
-                                       kernel_size,
-                                       dilation=dilation,
-                                       padding=padding)
-            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
-            self.in_layers.append(in_layer)
-
-            if i < num_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = torch.nn.Conv1d(hidden_channels,
-                                             res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
-                                                        name='weight')
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask=None, g=None, **kwargs):  # pylint: disable=unused-argument
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-        if g is not None:
-            g = self.cond_layer(g)
-        for i in range(self.num_layers):
-            x_in = self.in_layers[i](x)
-            x_in = self.dropout(x_in)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:,
-                        cond_offset:cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
-                                                   n_channels_tensor)
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.num_layers - 1:
-                x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels:, :]
-            else:
-                output = output + res_skip_acts
-        return output * x_mask
-
-    def remove_weight_norm(self):
-        if self.c_in_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
-        for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
-        for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)


 class InvConvNear(nn.Module):
     def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs):  # pylint: disable=unused-argument
         super().__init__()
@@ -166,7 +69,7 @@ class InvConvNear(nn.Module):
     def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
         """Split the input into groups of size self.num_splits and
         perform 1x1 convolution separately. Cast 1x1 conv operation
-        to 2d by reshaping the input for efficienty.
+        to 2d by reshaping the input for efficiency.
         """
         b, c, t = x.size()
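
The WN block and its fused gated activation moved out of this file (note the new TTS.tts.layers.generic.wavenet import above). For reference, the removed fused_add_tanh_sigmoid_multiply computes the WaveNet-style gate: the first half of the channels passes through tanh, the second half through sigmoid, and the two are multiplied. A minimal, self-contained sketch of the same computation:

import torch

def gated_activation(x_in, g_l, hidden_channels):
    # Equivalent to the removed fused_add_tanh_sigmoid_multiply (minus the
    # TorchScript fusion): add the conditioning, then gate tanh * sigmoid.
    in_act = x_in + g_l
    t_act = torch.tanh(in_act[:, :hidden_channels, :])
    s_act = torch.sigmoid(in_act[:, hidden_channels:, :])
    return t_act * s_act

# shapes: (batch, 2 * hidden_channels, time) in, (batch, hidden_channels, time) out
x_in = torch.randn(1, 2 * 192, 50)
g_l = torch.zeros_like(x_in)  # zero conditioning, as in WN when g is None
out = gated_activation(x_in, g_l, 192)  # -> torch.Size([1, 192, 50])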

View File

@@ -81,7 +81,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         # compute raw attention scores
         scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
             self.k_channels)
-        # relative positional encoding
+        # relative positional encoding for scores
         if self.rel_attn_window_size is not None:
             assert t_s == t_t, "Relative attention is only available for self-attention."
             # get relative key embeddings
@@ -262,7 +262,7 @@ class FFN(nn.Module):
         return x * x_mask


-class Transformer(nn.Module):
+class RelativePositionTransformer(nn.Module):
     def __init__(self,
                  hidden_channels,
                  hidden_channels_ffn,
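
For context, the scores line in the first hunk is plain scaled dot-product attention before any relative-position terms are added. A minimal sketch with illustrative shapes:

import math
import torch

# query: (batch, heads, t_t, k_channels), key: (batch, heads, t_s, k_channels)
query = torch.randn(1, 2, 10, 96)
key = torch.randn(1, 2, 10, 96)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
# scores: (batch, heads, t_t, t_s). When rel_attn_window_size is set, relative
# key terms are added to this matrix before the softmax (self-attention only,
# hence the t_s == t_t assertion above).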

View File

@@ -10,44 +10,59 @@ from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
 class GlowTts(nn.Module):
-    """Glow TTS models from https://arxiv.org/abs/2005.11129"""
+    """Glow TTS models from https://arxiv.org/abs/2005.11129
+
+    Args:
+        num_chars (int): number of embedding characters.
+        hidden_channels_enc (int): number of embedding and encoder channels.
+        hidden_channels_dec (int): number of decoder channels.
+        use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder.
+        hidden_channels_dp (int): number of duration predictor channels.
+        out_channels (int): number of output channels. It should be equal to the number of spectrogram filters.
+        num_flow_blocks_dec (int): number of decoder blocks.
+        kernel_size_dec (int): decoder kernel size.
+        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
+        num_block_layers (int): number of decoder layers in each decoder block.
+        dropout_p_dec (float): dropout rate for decoder.
+        num_speakers (int): number of speakers, used to define the size of the speaker embedding layer.
+        c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned.
+        num_splits (int): number of split levels in the invertible conv1x1 operation.
+        num_squeeze (int): number of squeeze levels. When squeezing, the number of channels increases and the number of time steps decreases by the factor 'num_squeeze'.
+        sigmoid_scale (bool): enable/disable sigmoid scaling in decoder.
+        mean_only (bool): if True, encoder only computes the mean value and uses constant variance for each time step.
+        encoder_type (str): encoder module type.
+        encoder_params (dict): encoder module parameters.
+        external_speaker_embedding_dim (int): channels of external speaker embedding vectors.
+    """
     def __init__(self,
                  num_chars,
-                 hidden_channels,
-                 hidden_channels_ffn,
+                 hidden_channels_enc,
+                 hidden_channels_dec,
+                 use_encoder_prenet,
                  hidden_channels_dp,
                  out_channels,
-                 num_heads=2,
-                 num_layers_enc=6,
-                 dropout_p=0.1,
                  num_flow_blocks_dec=12,
                  kernel_size_dec=5,
                  dilation_rate=5,
                  num_block_layers=4,
-                 dropout_p_dec=0.,
+                 dropout_p_dp=0.1,
+                 dropout_p_dec=0.05,
                  num_speakers=0,
                  c_in_channels=0,
                  num_splits=4,
-                 num_sqz=1,
+                 num_squeeze=1,
                  sigmoid_scale=False,
-                 rel_attn_window_size=None,
-                 input_length=None,
                  mean_only=False,
-                 hidden_channels_enc=None,
-                 hidden_channels_dec=None,
-                 use_encoder_prenet=False,
                  encoder_type="transformer",
+                 encoder_params=None,
                  external_speaker_embedding_dim=None):
         super().__init__()
         self.num_chars = num_chars
-        self.hidden_channels = hidden_channels
-        self.hidden_channels_ffn = hidden_channels_ffn
         self.hidden_channels_dp = hidden_channels_dp
+        self.hidden_channels_enc = hidden_channels_enc
+        self.hidden_channels_dec = hidden_channels_dec
         self.out_channels = out_channels
-        self.num_heads = num_heads
-        self.num_layers_enc = num_layers_enc
-        self.dropout_p = dropout_p
         self.num_flow_blocks_dec = num_flow_blocks_dec
         self.kernel_size_dec = kernel_size_dec
         self.dilation_rate = dilation_rate
@@ -56,16 +71,14 @@ class GlowTts(nn.Module):
         self.num_speakers = num_speakers
         self.c_in_channels = c_in_channels
         self.num_splits = num_splits
-        self.num_sqz = num_sqz
+        self.num_squeeze = num_squeeze
         self.sigmoid_scale = sigmoid_scale
-        self.rel_attn_window_size = rel_attn_window_size
-        self.input_length = input_length
         self.mean_only = mean_only
-        self.hidden_channels_enc = hidden_channels_enc
-        self.hidden_channels_dec = hidden_channels_dec
         self.use_encoder_prenet = use_encoder_prenet
-        self.noise_scale = 0.66
-        self.length_scale = 1.
+
+        # model constants.
+        self.noise_scale = 0.33  # defines the noise variance applied to the random z vector at inference.
+        self.length_scale = 1.  # scaler for the duration predictor. The larger it is, the slower the speech.
         self.external_speaker_embedding_dim = external_speaker_embedding_dim

         # if is a multispeaker and c_in_channels is 0, set to 256
@@ -77,27 +90,24 @@ class GlowTts(nn.Module):
         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
-                               hidden_channels=hidden_channels,
-                               hidden_channels_ffn=hidden_channels_ffn,
+                               hidden_channels=hidden_channels_enc,
                                hidden_channels_dp=hidden_channels_dp,
                                encoder_type=encoder_type,
-                               num_heads=num_heads,
-                               num_layers=num_layers_enc,
-                               dropout_p=dropout_p,
-                               rel_attn_window_size=rel_attn_window_size,
+                               encoder_params=encoder_params,
                                mean_only=mean_only,
                                use_prenet=use_encoder_prenet,
+                               dropout_p_dp=dropout_p_dp,
                                c_in_channels=self.c_in_channels)

         self.decoder = Decoder(out_channels,
-                               hidden_channels_dec or hidden_channels,
+                               hidden_channels_dec,
                                kernel_size_dec,
                                dilation_rate,
                                num_flow_blocks_dec,
                                num_block_layers,
                                dropout_p=dropout_p_dec,
                                num_splits=num_splits,
-                               num_sqz=num_sqz,
+                               num_squeeze=num_squeeze,
                                sigmoid_scale=sigmoid_scale,
                                c_in_channels=self.c_in_channels)
@@ -140,7 +150,7 @@ class GlowTts(nn.Module):
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                               x_lengths,
                                                               g=g)
-        # drop redisual frames wrt num_sqz and set y_lengths.
+        # drop redisual frames wrt num_squeeze and set y_lengths.
         y, y_lengths, y_max_length, attn = self.preprocess(
             y, y_lengths, y_max_length, None)
         # create masks
@@ -195,6 +205,7 @@ class GlowTts(nn.Module):
                      attn_mask.squeeze(1)).unsqueeze(1)
         y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
             attn, o_mean, o_log_scale, x_mask)
+
         z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
              self.noise_scale) * y_mask
         # decoder pass
@@ -204,11 +215,11 @@ class GlowTts(nn.Module):
     def preprocess(self, y, y_lengths, y_max_length, attn=None):
         if y_max_length is not None:
-            y_max_length = (y_max_length // self.num_sqz) * self.num_sqz
+            y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
             y = y[:, :, :y_max_length]
         if attn is not None:
             attn = attn[:, :, :, :y_max_length]
-        y_lengths = (y_lengths // self.num_sqz) * self.num_sqz
+        y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
         return y, y_lengths, y_max_length, attn

     def store_inverse(self):
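
The two constants documented above act only at inference: noise_scale scales the random component of the latent z (the same expression as the z line shown above), and length_scale stretches the predicted durations. A minimal sketch under illustrative shapes:

import torch

noise_scale = 0.33   # multiplier on the random component of z
length_scale = 1.0   # >1.0 slows the speech down, <1.0 speeds it up

y_mean = torch.zeros(1, 80, 100)       # decoder-frame means
y_log_scale = torch.zeros(1, 80, 100)  # decoder-frame log std-devs
y_mask = torch.ones(1, 1, 100)

# same expression as the z line in the diff above
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
     noise_scale) * y_mask

o_dur_log = torch.zeros(1, 1, 20)        # duration predictor output (log domain)
w = torch.exp(o_dur_log) * length_scale  # scaled per-character durations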

View File

@@ -103,15 +103,12 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                         speaker_embedding_dim=speaker_embedding_dim)
     elif c.model.lower() == "glow_tts":
         model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
-                        hidden_channels=192,
-                        hidden_channels_ffn=768,
-                        hidden_channels_dp=256,
+                        hidden_channels_enc=c['hidden_channels_encoder'],
+                        hidden_channels_dec=c['hidden_channels_decoder'],
+                        hidden_channels_dp=c['hidden_channels_duration_predictor'],
                         out_channels=c.audio['num_mels'],
-                        num_heads=2,
-                        num_layers_enc=6,
                         encoder_type=c.encoder_type,
-                        rel_attn_window_size=4,
-                        dropout_p=0.1,
+                        encoder_params=c.encoder_params,
                         num_flow_blocks_dec=12,
                         kernel_size_dec=5,
                         dilation_rate=1,
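
With this change the glow_tts branch of setup_model no longer hard-codes the model dimensions; they come from the training config. A hypothetical config fragment with the keys read above, using the previously hard-coded values as illustrative defaults (shown as a plain dict for brevity; the actual config object also supports attribute access such as c.encoder_type, which a plain dict does not):

# Hypothetical config fragment; keys match what setup_model reads above,
# values mirror the old hard-coded defaults and are illustrative only.
c = {
    "model": "glow_tts",
    "hidden_channels_encoder": 192,
    "hidden_channels_decoder": 192,
    "hidden_channels_duration_predictor": 256,
    "encoder_type": "rel_pos_transformer",
    "encoder_params": {          # unpacked into the selected encoder module
        "kernel_size": 3,
        "dropout_p": 0.1,
        "num_layers": 6,
        "num_heads": 2,
        "hidden_channels_ffn": 768,
    },
}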