mirror of https://github.com/coqui-ai/TTS.git
time seperable convolution encoder, huber loss for duration predictor
This commit is contained in:
parent
f1a75468c2
commit
3660c57f1e
|
@ -1,6 +1,8 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .normalization import LayerNorm
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, dropout_p):
|
||||
|
@ -16,12 +18,12 @@ class DurationPredictor(nn.Module):
|
|||
filter_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
self.norm_1 = nn.GroupNorm(1, filter_channels)
|
||||
self.norm_1 = LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(filter_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
self.norm_2 = nn.GroupNorm(1, filter_channels)
|
||||
self.norm_2 = LayerNorm(filter_channels)
|
||||
# output layer
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
|
|
|
@ -3,49 +3,11 @@ import torch
|
|||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.glow_tts.transformer import Transformer
|
||||
from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
|
||||
|
||||
class GatedConvBlock(nn.Module):
|
||||
"""Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf
|
||||
Args:
|
||||
in_out_channels (int): number of input/output channels.
|
||||
kernel_size (int): convolution kernel size.
|
||||
dropout_p (float): dropout rate.
|
||||
"""
|
||||
def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
|
||||
super().__init__()
|
||||
# class arguments
|
||||
self.dropout_p = dropout_p
|
||||
self.num_layers = num_layers
|
||||
# define layers
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.layers = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
self.conv_layers += [
|
||||
nn.Conv1d(in_out_channels,
|
||||
2 * in_out_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2)
|
||||
]
|
||||
self.norm_layers += [LayerNorm(2 * in_out_channels)]
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
o = x
|
||||
res = x
|
||||
for idx in range(self.num_layers):
|
||||
o = nn.functional.dropout(o,
|
||||
p=self.dropout_p,
|
||||
training=self.training)
|
||||
o = self.conv_layers[idx](o * x_mask)
|
||||
o = self.norm_layers[idx](o)
|
||||
o = nn.functional.glu(o, dim=1)
|
||||
o = res + o
|
||||
res = o
|
||||
return o
|
||||
from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
@ -82,7 +44,7 @@ class Encoder(nn.Module):
|
|||
rel_attn_window_size=None,
|
||||
input_length=None,
|
||||
mean_only=False,
|
||||
use_prenet=False,
|
||||
use_prenet=True,
|
||||
c_in_channels=0):
|
||||
super().__init__()
|
||||
# class arguments
|
||||
|
@ -127,6 +89,21 @@ class Encoder(nn.Module):
|
|||
kernel_size=5,
|
||||
dropout_p=dropout_p,
|
||||
num_layers=3 + num_layers)
|
||||
elif encoder_type.lower() == 'time-depth-separable':
|
||||
# optional convolutional prenet
|
||||
if use_prenet:
|
||||
self.pre = ConvLayerNorm(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3,
|
||||
dropout_p=0.5)
|
||||
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size=5,
|
||||
num_layers=3 + num_layers)
|
||||
|
||||
# final projection layers
|
||||
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
if not mean_only:
|
||||
|
@ -146,7 +123,7 @@ class Encoder(nn.Module):
|
|||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
|
||||
1).to(x.dtype)
|
||||
# pre-conv layers
|
||||
if self.encoder_type == 'transformer':
|
||||
if self.encoder_type in ['transformer', 'time-depth-separable']:
|
||||
if self.use_prenet:
|
||||
x = self.pre(x, x_mask)
|
||||
# encoder
|
||||
|
|
|
@ -2,31 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-4):
|
||||
"""Layer norm for the 2nd dimension of the input.
|
||||
Args:
|
||||
channels (int): number of channels (2nd dimension) of the input.
|
||||
eps (float): to prevent 0 division
|
||||
|
||||
Shapes:
|
||||
- input: (B, C, T)
|
||||
- output: (B, C, T)
|
||||
"""
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1)
|
||||
self.beta = nn.Parameter(torch.zeros(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
mean = torch.mean(x, 1, keepdim=True)
|
||||
variance = torch.mean((x - mean)**2, 1, keepdim=True)
|
||||
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
||||
x = x * self.gamma + self.beta
|
||||
return x
|
||||
from .normalization import LayerNorm
|
||||
|
||||
|
||||
class ConvLayerNorm(nn.Module):
|
||||
|
@ -70,7 +46,7 @@ class ConvLayerNorm(nn.Module):
|
|||
x = self.conv_layers[i](x * x_mask)
|
||||
x = self.norm_layers[i](x * x_mask)
|
||||
x = F.dropout(F.relu(x), self.dropout_p, training=self.training)
|
||||
x = x_org + self.proj(x)
|
||||
x = x_res + self.proj(x)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
|
|
|
@ -251,13 +251,19 @@ class GlowTTSLoss(torch.nn.Module):
|
|||
super(GlowTTSLoss, self).__init__()
|
||||
self.constant_factor = 0.5 * math.log(2 * math.pi)
|
||||
|
||||
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
|
||||
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log,
|
||||
o_attn_dur, x_lengths):
|
||||
return_dict = {}
|
||||
# flow loss
|
||||
pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means)**2)
|
||||
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths // 2) * 2 * 80)
|
||||
# duration loss
|
||||
loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
|
||||
# flow loss - neg log likelihood
|
||||
pz = torch.sum(scales) + 0.5 * torch.sum(
|
||||
torch.exp(-2 * scales) * (z - means)**2)
|
||||
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (
|
||||
torch.sum(y_lengths // 2) * 2 * z.shape[1])
|
||||
# duration loss - MSE
|
||||
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
|
||||
# duration loss - huber loss
|
||||
loss_dur = torch.nn.functional.smooth_l1_loss(
|
||||
o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths)
|
||||
return_dict['loss'] = log_mle + loss_dur
|
||||
return_dict['log_mle'] = log_mle
|
||||
return_dict['loss_dur'] = loss_dur
|
||||
|
|
|
@ -79,6 +79,7 @@ class GlowTts(nn.Module):
|
|||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
c_in_channels=c_in_channels)
|
||||
|
||||
|
||||
|
|
|
@ -48,7 +48,6 @@ class ResidualStack(tf.keras.layers.Layer):
|
|||
]
|
||||
|
||||
def call(self, x):
|
||||
# breakpoint()
|
||||
for block, shortcut in zip(self.blocks, self.shortcuts):
|
||||
res = shortcut(x)
|
||||
for layer in block:
|
||||
|
|
Loading…
Reference in New Issue