mirror of https://github.com/coqui-ai/TTS.git
class renaming
parent c0a2aa68d3
commit 1d961d6f8a
@@ -5,10 +5,10 @@ from torch import nn
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
+from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock


 class Encoder(nn.Module):
@@ -97,14 +97,16 @@ class Encoder(nn.Module):
         # init encoder module
         if encoder_type.lower() == "rel_pos_transformer":
             if use_prenet:
-                self.prenet = ConvLayerNorm(hidden_channels,
+                self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                     hidden_channels,
                     hidden_channels,
                     kernel_size=5,
                     num_layers=3,
                     dropout_p=0.5)
-            self.encoder = RelativePositionTransformer(
-                hidden_channels, **encoder_params)
+            self.encoder = RelativePositionTransformer(hidden_channels,
+                hidden_channels,
+                hidden_channels,
+                **encoder_params)
         elif encoder_type.lower() == 'gated_conv':
             self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
         elif encoder_type.lower() == 'residual_conv_bn':
@@ -113,13 +115,16 @@ class Encoder(nn.Module):
                 nn.Conv1d(hidden_channels, hidden_channels, 1),
                 nn.ReLU()
             )
-            self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
+            self.encoder = ResidualConv1dBNBlock(hidden_channels,
+                hidden_channels,
+                hidden_channels,
+                **encoder_params)
             self.postnet = nn.Sequential(
                 nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
                 nn.BatchNorm1d(self.hidden_channels))
         elif encoder_type.lower() == 'time_depth_separable':
             if use_prenet:
-                self.prenet = ConvLayerNorm(hidden_channels,
+                self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                     hidden_channels,
                     hidden_channels,
                     kernel_size=5,
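For orientation, here is a minimal usage sketch of the renamed blocks as the Encoder branches above now wire them. It assumes the import paths and argument layout visible in these hunks (three positional channel sizes, then the keyword arguments) and that both blocks keep the `(x, x_mask)` forward signature the Encoder already uses; the concrete sizes (192 channels, 2 heads, 6 layers, batch of 2, 50 frames) are illustrative values, not taken from the commit.

```python
import torch

from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer

hidden_channels = 192  # illustrative size, not from the commit

# prenet as built in the 'rel_pos_transformer' / 'time_depth_separable' branches
prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                      hidden_channels,
                                      hidden_channels,
                                      kernel_size=5,
                                      num_layers=3,
                                      dropout_p=0.5)

# transformer encoder with the new explicit in/out channel arguments; keyword
# names are assumed from the parameters shown and used elsewhere in this diff
encoder = RelativePositionTransformer(hidden_channels,
                                      hidden_channels,
                                      hidden_channels,
                                      hidden_channels_ffn=768,  # example value
                                      num_heads=2,              # example value
                                      num_layers=6,             # example value
                                      kernel_size=3,            # example value
                                      dropout_p=0.1)            # example value

x = torch.randn(2, hidden_channels, 50)  # [batch, channels, time]
x_mask = torch.ones(2, 1, 50)            # all frames valid

o = prenet(x, x_mask)   # assumed (x, x_mask) signature, as in the Encoder
o = encoder(o, x_mask)  # forward(x, x_mask), as defined at the end of this diff
```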
@@ -264,7 +264,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         return diff.unsqueeze(0).unsqueeze(0)


-class FFN(nn.Module):
+class FeedForwardNetwork(nn.Module):
     """Feed Forward Inner layers for Transformer.

     Args:
@@ -326,6 +326,8 @@ class RelativePositionTransformer(nn.Module):
         input_length (int, optional): input lenght to limit position encoding. Defaults to None.
     """
     def __init__(self,
+                 in_channels,
+                 out_channels,
                  hidden_channels,
                  hidden_channels_ffn,
                  num_heads,
@@ -348,23 +350,31 @@ class RelativePositionTransformer(nn.Module):
         self.norm_layers_1 = nn.ModuleList()
         self.ffn_layers = nn.ModuleList()
         self.norm_layers_2 = nn.ModuleList()
-        for _ in range(self.num_layers):
+        for idx in range(self.num_layers):
             self.attn_layers.append(
                 RelativePositionMultiHeadAttention(
-                    hidden_channels,
+                    hidden_channels if idx != 0 else in_channels,
                     hidden_channels,
                     num_heads,
                     rel_attn_window_size=rel_attn_window_size,
                     dropout_p=dropout_p,
                     input_length=input_length))
             self.norm_layers_1.append(LayerNorm(hidden_channels))
+
+            if hidden_channels != out_channels and (idx + 1) == self.num_layers:
+                self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
             self.ffn_layers.append(
-                FFN(hidden_channels,
-                    hidden_channels,
+                FeedForwardNetwork(hidden_channels,
+                    hidden_channels if (idx + 1) != self.num_layers else out_channels,
                     hidden_channels_ffn,
                     kernel_size,
                     dropout_p=dropout_p))
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
+            self.norm_layers_2.append(
+                LayerNorm(hidden_channels if (
+                    idx + 1) != self.num_layers else out_channels))

     def forward(self, x, x_mask):
         """
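The `idx`-dependent widths introduced above make the stack consume `in_channels` at the first attention layer and emit `out_channels` from the last feed-forward and LayerNorm, preparing a 1x1 `proj` convolution when the hidden and output widths differ. Below is a standalone sketch of that channel-bridging pattern with plain `nn.Conv1d` layers standing in for the attention/FFN blocks; it illustrates the wiring only and is not the library implementation.

```python
import torch
from torch import nn


class BridgedStack(nn.Module):
    """Toy stack showing the first-layer/last-layer channel bridging from the diff."""

    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, kernel_size=3):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            # first layer reads in_channels, last layer writes out_channels
            c_in = hidden_channels if idx != 0 else in_channels
            c_out = hidden_channels if (idx + 1) != num_layers else out_channels
            self.layers.append(
                nn.Conv1d(c_in, c_out, kernel_size, padding=kernel_size // 2))
            if hidden_channels != out_channels and (idx + 1) == num_layers:
                # prepared for a skip/residual path, mirroring the diff; unused here
                self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x


stack = BridgedStack(in_channels=80, out_channels=192, hidden_channels=256, num_layers=4)
print(stack(torch.randn(2, 80, 50)).shape)  # torch.Size([2, 192, 50])
```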