From 46728895492309f2a819bdd27e7607dfa68ae6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 3 Sep 2021 13:27:56 +0000 Subject: [PATCH] Update `generic.FFTransformer` --- TTS/tts/layers/generic/transformer.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 9e6b69ac..12f0bbb0 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -15,17 +15,19 @@ class FFTransformer(nn.Module): self.norm1 = nn.LayerNorm(in_out_channels) self.norm2 = nn.LayerNorm(in_out_channels) - self.dropout = nn.Dropout(dropout_p) + self.dropout1 = nn.Dropout(dropout_p) + self.dropout2 = nn.Dropout(dropout_p) def forward(self, src, src_mask=None, src_key_padding_mask=None): """😦 ugly looking with all the transposing""" src = src.permute(2, 0, 1) src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) src = self.norm1(src + src2) # T x B x D -> B x D x T src = src.permute(1, 2, 0) src2 = self.conv2(F.relu(self.conv1(src))) - src2 = self.dropout(src2) + src2 = self.dropout2(src2) src = src + src2 src = src.transpose(1, 2) src = self.norm2(src) @@ -52,8 +54,8 @@ class FFTransformerBlock(nn.Module): """ TODO: handle multi-speaker Shapes: - x: [B, C, T] - mask: [B, 1, T] or [B, T] + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T] or [B, T]` """ if mask is not None and mask.ndim == 3: mask = mask.squeeze(1) @@ -65,3 +67,19 @@ class FFTransformerBlock(nn.Module): alignments.append(align.unsqueeze(1)) alignments = torch.cat(alignments, 1) return x + + +class FFTDurationPredictor: + def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): + self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.proj = nn.Linear(in_channels, 1) + + def forward(self, x, mask=None, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T]` + """ + x = self.fft(x, mask=mask) + x = self.proj(x) + return x