mirror of https://github.com/coqui-ai/TTS.git
time seperable convolution encoder, huber loss for duration predictor
This commit is contained in:
parent
f1a75468c2
commit
3660c57f1e
|
@ -1,6 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from .normalization import LayerNorm
|
||||||
|
|
||||||
|
|
||||||
class DurationPredictor(nn.Module):
|
class DurationPredictor(nn.Module):
|
||||||
def __init__(self, in_channels, filter_channels, kernel_size, dropout_p):
|
def __init__(self, in_channels, filter_channels, kernel_size, dropout_p):
|
||||||
|
@ -16,12 +18,12 @@ class DurationPredictor(nn.Module):
|
||||||
filter_channels,
|
filter_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
padding=kernel_size // 2)
|
padding=kernel_size // 2)
|
||||||
self.norm_1 = nn.GroupNorm(1, filter_channels)
|
self.norm_1 = LayerNorm(filter_channels)
|
||||||
self.conv_2 = nn.Conv1d(filter_channels,
|
self.conv_2 = nn.Conv1d(filter_channels,
|
||||||
filter_channels,
|
filter_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
padding=kernel_size // 2)
|
padding=kernel_size // 2)
|
||||||
self.norm_2 = nn.GroupNorm(1, filter_channels)
|
self.norm_2 = LayerNorm(filter_channels)
|
||||||
# output layer
|
# output layer
|
||||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||||
|
|
||||||
|
|
|
@ -3,49 +3,11 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from TTS.tts.layers.glow_tts.transformer import Transformer
|
from TTS.tts.layers.glow_tts.transformer import Transformer
|
||||||
|
from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock
|
||||||
from TTS.tts.utils.generic_utils import sequence_mask
|
from TTS.tts.utils.generic_utils import sequence_mask
|
||||||
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
|
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
|
||||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
|
from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock
|
||||||
|
|
||||||
class GatedConvBlock(nn.Module):
|
|
||||||
"""Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf
|
|
||||||
Args:
|
|
||||||
in_out_channels (int): number of input/output channels.
|
|
||||||
kernel_size (int): convolution kernel size.
|
|
||||||
dropout_p (float): dropout rate.
|
|
||||||
"""
|
|
||||||
def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
|
|
||||||
super().__init__()
|
|
||||||
# class arguments
|
|
||||||
self.dropout_p = dropout_p
|
|
||||||
self.num_layers = num_layers
|
|
||||||
# define layers
|
|
||||||
self.conv_layers = nn.ModuleList()
|
|
||||||
self.norm_layers = nn.ModuleList()
|
|
||||||
self.layers = nn.ModuleList()
|
|
||||||
for _ in range(num_layers):
|
|
||||||
self.conv_layers += [
|
|
||||||
nn.Conv1d(in_out_channels,
|
|
||||||
2 * in_out_channels,
|
|
||||||
kernel_size,
|
|
||||||
padding=kernel_size // 2)
|
|
||||||
]
|
|
||||||
self.norm_layers += [LayerNorm(2 * in_out_channels)]
|
|
||||||
|
|
||||||
def forward(self, x, x_mask):
|
|
||||||
o = x
|
|
||||||
res = x
|
|
||||||
for idx in range(self.num_layers):
|
|
||||||
o = nn.functional.dropout(o,
|
|
||||||
p=self.dropout_p,
|
|
||||||
training=self.training)
|
|
||||||
o = self.conv_layers[idx](o * x_mask)
|
|
||||||
o = self.norm_layers[idx](o)
|
|
||||||
o = nn.functional.glu(o, dim=1)
|
|
||||||
o = res + o
|
|
||||||
res = o
|
|
||||||
return o
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
|
@ -82,7 +44,7 @@ class Encoder(nn.Module):
|
||||||
rel_attn_window_size=None,
|
rel_attn_window_size=None,
|
||||||
input_length=None,
|
input_length=None,
|
||||||
mean_only=False,
|
mean_only=False,
|
||||||
use_prenet=False,
|
use_prenet=True,
|
||||||
c_in_channels=0):
|
c_in_channels=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# class arguments
|
# class arguments
|
||||||
|
@ -127,6 +89,21 @@ class Encoder(nn.Module):
|
||||||
kernel_size=5,
|
kernel_size=5,
|
||||||
dropout_p=dropout_p,
|
dropout_p=dropout_p,
|
||||||
num_layers=3 + num_layers)
|
num_layers=3 + num_layers)
|
||||||
|
elif encoder_type.lower() == 'time-depth-separable':
|
||||||
|
# optional convolutional prenet
|
||||||
|
if use_prenet:
|
||||||
|
self.pre = ConvLayerNorm(hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size=5,
|
||||||
|
num_layers=3,
|
||||||
|
dropout_p=0.5)
|
||||||
|
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size=5,
|
||||||
|
num_layers=3 + num_layers)
|
||||||
|
|
||||||
# final projection layers
|
# final projection layers
|
||||||
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
|
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||||
if not mean_only:
|
if not mean_only:
|
||||||
|
@ -146,7 +123,7 @@ class Encoder(nn.Module):
|
||||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
|
||||||
1).to(x.dtype)
|
1).to(x.dtype)
|
||||||
# pre-conv layers
|
# pre-conv layers
|
||||||
if self.encoder_type == 'transformer':
|
if self.encoder_type in ['transformer', 'time-depth-separable']:
|
||||||
if self.use_prenet:
|
if self.use_prenet:
|
||||||
x = self.pre(x, x_mask)
|
x = self.pre(x, x_mask)
|
||||||
# encoder
|
# encoder
|
||||||
|
|
|
@ -2,31 +2,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .normalization import LayerNorm
|
||||||
class LayerNorm(nn.Module):
|
|
||||||
def __init__(self, channels, eps=1e-4):
|
|
||||||
"""Layer norm for the 2nd dimension of the input.
|
|
||||||
Args:
|
|
||||||
channels (int): number of channels (2nd dimension) of the input.
|
|
||||||
eps (float): to prevent 0 division
|
|
||||||
|
|
||||||
Shapes:
|
|
||||||
- input: (B, C, T)
|
|
||||||
- output: (B, C, T)
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.channels = channels
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1)
|
|
||||||
self.beta = nn.Parameter(torch.zeros(1, channels, 1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
mean = torch.mean(x, 1, keepdim=True)
|
|
||||||
variance = torch.mean((x - mean)**2, 1, keepdim=True)
|
|
||||||
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
|
||||||
x = x * self.gamma + self.beta
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ConvLayerNorm(nn.Module):
|
class ConvLayerNorm(nn.Module):
|
||||||
|
@ -70,7 +46,7 @@ class ConvLayerNorm(nn.Module):
|
||||||
x = self.conv_layers[i](x * x_mask)
|
x = self.conv_layers[i](x * x_mask)
|
||||||
x = self.norm_layers[i](x * x_mask)
|
x = self.norm_layers[i](x * x_mask)
|
||||||
x = F.dropout(F.relu(x), self.dropout_p, training=self.training)
|
x = F.dropout(F.relu(x), self.dropout_p, training=self.training)
|
||||||
x = x_org + self.proj(x)
|
x = x_res + self.proj(x)
|
||||||
return x * x_mask
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -251,13 +251,19 @@ class GlowTTSLoss(torch.nn.Module):
|
||||||
super(GlowTTSLoss, self).__init__()
|
super(GlowTTSLoss, self).__init__()
|
||||||
self.constant_factor = 0.5 * math.log(2 * math.pi)
|
self.constant_factor = 0.5 * math.log(2 * math.pi)
|
||||||
|
|
||||||
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
|
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log,
|
||||||
|
o_attn_dur, x_lengths):
|
||||||
return_dict = {}
|
return_dict = {}
|
||||||
# flow loss
|
# flow loss - neg log likelihood
|
||||||
pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means)**2)
|
pz = torch.sum(scales) + 0.5 * torch.sum(
|
||||||
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths // 2) * 2 * 80)
|
torch.exp(-2 * scales) * (z - means)**2)
|
||||||
# duration loss
|
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (
|
||||||
loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
|
torch.sum(y_lengths // 2) * 2 * z.shape[1])
|
||||||
|
# duration loss - MSE
|
||||||
|
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
|
||||||
|
# duration loss - huber loss
|
||||||
|
loss_dur = torch.nn.functional.smooth_l1_loss(
|
||||||
|
o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths)
|
||||||
return_dict['loss'] = log_mle + loss_dur
|
return_dict['loss'] = log_mle + loss_dur
|
||||||
return_dict['log_mle'] = log_mle
|
return_dict['log_mle'] = log_mle
|
||||||
return_dict['loss_dur'] = loss_dur
|
return_dict['loss_dur'] = loss_dur
|
||||||
|
|
|
@ -79,6 +79,7 @@ class GlowTts(nn.Module):
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
dropout_p=dropout_p,
|
dropout_p=dropout_p,
|
||||||
mean_only=mean_only,
|
mean_only=mean_only,
|
||||||
|
use_prenet=use_encoder_prenet,
|
||||||
c_in_channels=c_in_channels)
|
c_in_channels=c_in_channels)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,6 @@ class ResidualStack(tf.keras.layers.Layer):
|
||||||
]
|
]
|
||||||
|
|
||||||
def call(self, x):
|
def call(self, x):
|
||||||
# breakpoint()
|
|
||||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||||
res = shortcut(x)
|
res = shortcut(x)
|
||||||
for layer in block:
|
for layer in block:
|
||||||
|
|
Loading…
Reference in New Issue