Update `generic.FFTransformer`

This commit is contained in:
Eren Gölge 2021-09-03 13:27:56 +00:00
parent 2bf9e83c49
commit 4672889549
1 changed files with 22 additions and 4 deletions

View File

@ -15,17 +15,19 @@ class FFTransformer(nn.Module):
self.norm1 = nn.LayerNorm(in_out_channels) self.norm1 = nn.LayerNorm(in_out_channels)
self.norm2 = nn.LayerNorm(in_out_channels) self.norm2 = nn.LayerNorm(in_out_channels)
self.dropout = nn.Dropout(dropout_p) self.dropout1 = nn.Dropout(dropout_p)
self.dropout2 = nn.Dropout(dropout_p)
def forward(self, src, src_mask=None, src_key_padding_mask=None): def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""😦 ugly looking with all the transposing""" """😦 ugly looking with all the transposing"""
src = src.permute(2, 0, 1) src = src.permute(2, 0, 1)
src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
src = src + self.dropout1(src2)
src = self.norm1(src + src2) src = self.norm1(src + src2)
# T x B x D -> B x D x T # T x B x D -> B x D x T
src = src.permute(1, 2, 0) src = src.permute(1, 2, 0)
src2 = self.conv2(F.relu(self.conv1(src))) src2 = self.conv2(F.relu(self.conv1(src)))
src2 = self.dropout(src2) src2 = self.dropout2(src2)
src = src + src2 src = src + src2
src = src.transpose(1, 2) src = src.transpose(1, 2)
src = self.norm2(src) src = self.norm2(src)
@ -52,8 +54,8 @@ class FFTransformerBlock(nn.Module):
""" """
TODO: handle multi-speaker TODO: handle multi-speaker
Shapes: Shapes:
x: [B, C, T] - x: :math:`[B, C, T]`
mask: [B, 1, T] or [B, T] - mask: :math:`[B, 1, T] or [B, T]`
""" """
if mask is not None and mask.ndim == 3: if mask is not None and mask.ndim == 3:
mask = mask.squeeze(1) mask = mask.squeeze(1)
@ -65,3 +67,19 @@ class FFTransformerBlock(nn.Module):
alignments.append(align.unsqueeze(1)) alignments.append(align.unsqueeze(1))
alignments = torch.cat(alignments, 1) alignments = torch.cat(alignments, 1)
return x return x
class FFTDurationPredictor:
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
self.proj = nn.Linear(in_channels, 1)
def forward(self, x, mask=None, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- mask: :math:`[B, 1, T]`
"""
x = self.fft(x, mask=mask)
x = self.proj(x)
return x