mirror of https://github.com/coqui-ai/TTS.git
class renaming
parent c0a2aa68d3
commit 1d961d6f8a
@@ -5,10 +5,10 @@ from torch import nn
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.generic.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
+from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock


 class Encoder(nn.Module):
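
If other code imports these layers under their old names, the rename is a one-line change per import. A minimal before/after sketch, assuming such downstream imports exist; the module paths are the ones shown in the hunk above:

# Before this commit (old class names):
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock

# After this commit (renamed classes, same modules):
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
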
@@ -97,14 +97,16 @@ class Encoder(nn.Module):
         # init encoder module
         if encoder_type.lower() == "rel_pos_transformer":
             if use_prenet:
-                self.prenet = ConvLayerNorm(hidden_channels,
+                self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                             hidden_channels,
                                             hidden_channels,
                                             kernel_size=5,
                                             num_layers=3,
                                             dropout_p=0.5)
-            self.encoder = RelativePositionTransformer(
-                hidden_channels, **encoder_params)
+            self.encoder = RelativePositionTransformer(hidden_channels,
+                                                       hidden_channels,
+                                                       hidden_channels,
+                                                       **encoder_params)
         elif encoder_type.lower() == 'gated_conv':
             self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
         elif encoder_type.lower() == 'residual_conv_bn':
@@ -113,13 +115,16 @@ class Encoder(nn.Module):
                 nn.Conv1d(hidden_channels, hidden_channels, 1),
                 nn.ReLU()
             )
-            self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
+            self.encoder = ResidualConv1dBNBlock(hidden_channels,
+                                                 hidden_channels,
+                                                 hidden_channels,
+                                                 **encoder_params)
             self.postnet = nn.Sequential(
                 nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
                 nn.BatchNorm1d(self.hidden_channels))
         elif encoder_type.lower() == 'time_depth_separable':
             if use_prenet:
-                self.prenet = ConvLayerNorm(hidden_channels,
+                self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                             hidden_channels,
                                             hidden_channels,
                                             kernel_size=5,
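
As a quick sanity check of the renamed prenet block, the call above can be reproduced in isolation. A minimal sketch, assuming TTS at this commit is installed; the channel count of 192 is an arbitrary example value, and only construction is exercised (no forward pass):

from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock

hidden_channels = 192  # assumed example value, not taken from this commit

# Same call shape as the encoder prenet above: in, hidden and out channel
# counts are all `hidden_channels`.
prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                      hidden_channels,
                                      hidden_channels,
                                      kernel_size=5,
                                      num_layers=3,
                                      dropout_p=0.5)

# Count trainable parameters as a cheap smoke test.
n_params = sum(p.numel() for p in prenet.parameters())
print(f"ResidualConv1dLayerNormBlock parameters: {n_params}")
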
@@ -264,7 +264,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         return diff.unsqueeze(0).unsqueeze(0)


-class FFN(nn.Module):
+class FeedForwardNetwork(nn.Module):
     """Feed Forward Inner layers for Transformer.

     Args:
@@ -326,6 +326,8 @@ class RelativePositionTransformer(nn.Module):
         input_length (int, optional): input lenght to limit position encoding. Defaults to None.
     """
     def __init__(self,
+                 in_channels,
+                 out_channels,
                  hidden_channels,
                  hidden_channels_ffn,
                  num_heads,
@@ -348,23 +350,31 @@ class RelativePositionTransformer(nn.Module):
         self.norm_layers_1 = nn.ModuleList()
         self.ffn_layers = nn.ModuleList()
         self.norm_layers_2 = nn.ModuleList()
-        for _ in range(self.num_layers):
+
+        for idx in range(self.num_layers):
             self.attn_layers.append(
                 RelativePositionMultiHeadAttention(
-                    hidden_channels,
+                    hidden_channels if idx != 0 else in_channels,
                     hidden_channels,
                     num_heads,
                     rel_attn_window_size=rel_attn_window_size,
                     dropout_p=dropout_p,
                     input_length=input_length))
             self.norm_layers_1.append(LayerNorm(hidden_channels))
+
+            if hidden_channels != out_channels and (idx + 1) == self.num_layers:
+                self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
             self.ffn_layers.append(
-                FFN(hidden_channels,
-                    hidden_channels,
+                FeedForwardNetwork(hidden_channels,
+                                   hidden_channels if (idx + 1) != self.num_layers else out_channels,
                     hidden_channels_ffn,
                     kernel_size,
                     dropout_p=dropout_p))
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
+            self.norm_layers_2.append(
+                LayerNorm(hidden_channels if (
+                    idx + 1) != self.num_layers else out_channels))

     def forward(self, x, x_mask):
         """
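
The net effect of the loop is that the stack now bridges channel sizes at its boundaries: the first attention layer consumes in_channels, the last feed-forward layer and its LayerNorm emit out_channels, and a 1x1 nn.Conv1d projection is registered on the final iteration when hidden_channels != out_channels. A self-contained sketch of just that indexing pattern, with plain Conv1d layers standing in for the attention and feed-forward blocks (not the TTS classes; all channel numbers are made-up examples):

import torch
from torch import nn

def build_stack(in_channels, out_channels, hidden_channels, num_layers):
    # Reproduces the two conditional expressions from the diff: the input
    # rule used for the first attention layer and the output rule used for
    # the last feed-forward layer / LayerNorm.
    layers = nn.ModuleList()
    for idx in range(num_layers):
        layer_in = hidden_channels if idx != 0 else in_channels
        layer_out = hidden_channels if (idx + 1) != num_layers else out_channels
        layers.append(nn.Conv1d(layer_in, layer_out, 1))
    return layers

stack = build_stack(in_channels=80, out_channels=192, hidden_channels=256, num_layers=4)
x = torch.randn(1, 80, 50)   # [batch, channels, time]
for layer in stack:
    x = layer(x)
print(x.shape)               # -> torch.Size([1, 192, 50])
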