class renaming

erogol 2021-01-11 17:26:11 +01:00
parent c0a2aa68d3
commit 1d961d6f8a
2 changed files with 28 additions and 13 deletions


@@ -5,10 +5,10 @@ from torch import nn
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.gated_conv import GatedConvBlock
from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
+from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
class Encoder(nn.Module):
@@ -97,14 +97,16 @@ class Encoder(nn.Module):
# init encoder module
if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:
-self.prenet = ConvLayerNorm(hidden_channels,
+self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
-self.encoder = RelativePositionTransformer(
-hidden_channels, **encoder_params)
+self.encoder = RelativePositionTransformer(hidden_channels,
+hidden_channels,
+hidden_channels,
+**encoder_params)
elif encoder_type.lower() == 'gated_conv':
self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == 'residual_conv_bn':
@@ -113,13 +115,16 @@ class Encoder(nn.Module):
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU()
)
-self.encoder = ResidualConvBNBlock(hidden_channels, **encoder_params)
+self.encoder = ResidualConv1dBNBlock(hidden_channels,
+hidden_channels,
+hidden_channels,
+**encoder_params)
self.postnet = nn.Sequential(
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
nn.BatchNorm1d(self.hidden_channels))
elif encoder_type.lower() == 'time_depth_separable':
if use_prenet:
-self.prenet = ConvLayerNorm(hidden_channels,
+self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,

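For downstream code that tracks this rename, a minimal usage sketch of the updated encoder blocks. It mirrors the constructor calls shown in the hunks above; the value 192, the tensor sizes, and the assumption that the prenet keeps an (x, x_mask) forward signature are illustrative and not taken from this diff.

import torch

# Renamed imports introduced by this commit.
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock  # formerly ConvLayerNorm
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock  # formerly ResidualConvBNBlock; shown only for the new path, its encoder_params are config-dependent and not part of this diff

hidden_channels = 192  # example value, not taken from the diff

# Same construction as the prenet above: three positional channel arguments
# (all hidden_channels, exactly as in the diff) plus the shown keyword arguments.
prenet = ResidualConv1dLayerNormBlock(hidden_channels,
                                      hidden_channels,
                                      hidden_channels,
                                      kernel_size=5,
                                      num_layers=3,
                                      dropout_p=0.5)

x = torch.rand(2, hidden_channels, 50)  # [batch, channels, time]
x_mask = torch.ones(2, 1, 50)           # [batch, 1, time]
out = prenet(x, x_mask)                 # assumed (x, x_mask) signature, output [2, 192, 50]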

@@ -264,7 +264,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
return diff.unsqueeze(0).unsqueeze(0)
-class FFN(nn.Module):
+class FeedForwardNetwork(nn.Module):
"""Feed Forward Inner layers for Transformer.
Args:
@@ -326,6 +326,8 @@ class RelativePositionTransformer(nn.Module):
input_length (int, optional): input lenght to limit position encoding. Defaults to None.
"""
def __init__(self,
+in_channels,
+out_channels,
hidden_channels,
hidden_channels_ffn,
num_heads,
@@ -348,23 +350,31 @@ class RelativePositionTransformer(nn.Module):
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
-for _ in range(self.num_layers):
+for idx in range(self.num_layers):
self.attn_layers.append(
RelativePositionMultiHeadAttention(
-hidden_channels,
+hidden_channels if idx != 0 else in_channels,
hidden_channels,
num_heads,
rel_attn_window_size=rel_attn_window_size,
dropout_p=dropout_p,
input_length=input_length))
self.norm_layers_1.append(LayerNorm(hidden_channels))
+if hidden_channels != out_channels and (idx + 1) == self.num_layers:
+self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.ffn_layers.append(
-FFN(hidden_channels,
-hidden_channels,
+FeedForwardNetwork(hidden_channels,
+hidden_channels if (idx + 1) != self.num_layers else out_channels,
hidden_channels_ffn,
kernel_size,
dropout_p=dropout_p))
-self.norm_layers_2.append(LayerNorm(hidden_channels))
+self.norm_layers_2.append(
+LayerNorm(hidden_channels if (
+idx + 1) != self.num_layers else out_channels))
def forward(self, x, x_mask):
"""