mirror of https://github.com/coqui-ai/TTS.git

Add transformer block to UNet

parent 5fd7ea93ea
commit f15230bb67
@@ -1,7 +1,8 @@
 import math
-from einops import pack
+from einops import pack, rearrange
 import torch
 from torch import nn
+import conformer
 
 
 class PositionalEncoding(torch.nn.Module):
@@ -71,6 +72,40 @@ class Upsample1D(nn.Module):
 
     def forward(self, x):
         return self.conv(x)
 
 
+class ConformerBlock(conformer.ConformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        dim_head: int = 64,
+        heads: int = 8,
+        ff_mult: int = 4,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        attn_dropout: float = 0.,
+        ff_dropout: float = 0.,
+        conv_dropout: float = 0.,
+        conv_causal: bool = False,
+    ):
+        super().__init__(
+            dim=dim,
+            dim_head=dim_head,
+            heads=heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            conv_dropout=conv_dropout,
+            conv_causal=conv_causal,
+        )
+
+    def forward(self, x, mask):
+        x = rearrange(x, "b c t -> b t c")
+        mask = rearrange(mask, "b 1 t -> b t")
+        output = super().forward(x=x, mask=mask.bool())
+        return rearrange(output, "b t c -> b c t")
+
+
 class UNet(nn.Module):
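Not part of the commit, but for orientation: the subclass above exists to bridge layouts. The stock conformer block (lucidrains' `conformer` package) works on (batch, time, channels) tensors, while this UNet is channel-first. A minimal shape check, assuming that package is installed and the wrapper above is in scope:

# Sanity sketch (not part of the commit), assuming lucidrains' `conformer`
# package and the ConformerBlock wrapper defined above.
import torch

block = ConformerBlock(dim=256, heads=4, dim_head=64, ff_mult=1)

x = torch.randn(2, 256, 100)   # (batch, channels, time)
mask = torch.ones(2, 1, 100)   # (batch, 1, time); 1 marks valid frames
out = block(x, mask)           # internally: b c t -> b t c -> conformer -> b c t
print(out.shape)               # torch.Size([2, 256, 100])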
@@ -80,6 +115,12 @@ class UNet(nn.Module):
         model_channels: int,
         out_channels: int,
         num_blocks: int,
+        transformer_num_heads: int = 4,
+        transformer_dim_head: int = 64,
+        transformer_ff_mult: int = 1,
+        transformer_conv_expansion_factor: int = 2,
+        transformer_conv_kernel_size: int = 31,
+        transformer_dropout: float = 0.05,
     ):
         super().__init__()
         self.in_channels = in_channels
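The six new `transformer_*` options default to a small block (`ff_mult=1`, 5% dropout). A hedged instantiation sketch follows; every argument other than the `transformer_*` options is read off the visible context of this diff, the values are illustrative, and any constructor arguments not shown in this hunk are omitted:

# Illustrative only: channel counts are made-up values, and the real
# constructor may require arguments not visible in this hunk.
unet = UNet(
    in_channels=160,
    model_channels=256,
    out_channels=80,
    num_blocks=2,
    transformer_num_heads=4,
    transformer_dim_head=64,
    transformer_ff_mult=1,
    transformer_conv_expansion_factor=2,
    transformer_conv_kernel_size=31,
    transformer_dropout=0.05,
)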
@@ -107,6 +148,18 @@ class UNet(nn.Module):
                 )
             )
 
+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
             if level != num_blocks - 1:
                 block.append(Downsample1D(block_out_channels))
             else:
@@ -116,6 +169,30 @@ class UNet(nn.Module):
             self.input_blocks.append(block)
 
+        self.middle_blocks = nn.ModuleList([])
+        for _ in range(2):
+            block = nn.ModuleList([])
+
+            block.append(
+                ResNetBlock1D(
+                    in_channels=block_out_channels,
+                    out_channels=block_out_channels,
+                    time_embed_channels=time_embed_channels
+                )
+            )
+
+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
+            self.middle_blocks.append(block)
+
         self.output_blocks = nn.ModuleList([])
         block_in_channels = block_out_channels * 2
@@ -131,6 +208,18 @@ class UNet(nn.Module):
                 )
             )
 
+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
             if level != num_blocks - 1:
                 block.append(Upsample1D(block_out_channels))
             else:
@@ -142,6 +231,29 @@ class UNet(nn.Module):
         self.conv_block = ConvBlock1D(model_channels, model_channels)
         self.conv = nn.Conv1d(model_channels, self.out_channels, 1)
 
+    def _create_transformer_block(
+        self,
+        dim,
+        dim_head: int = 64,
+        num_heads: int = 4,
+        ff_mult: int = 1,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        dropout: float = 0.05,
+    ):
+        return ConformerBlock(
+            dim=dim,
+            dim_head=dim_head,
+            heads=num_heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=dropout,
+            ff_dropout=dropout,
+            conv_dropout=dropout,
+            conv_causal=False,
+        )
+
     def forward(self, x_t, mean, mask, t):
         t = self.time_encoder(t)
         t = self.time_embed(t)
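The helper fans the single `dropout` value out to the conformer's three dropout knobs (attention, feed-forward, convolution) and pins `conv_causal=False`. For `dim=256` with the defaults above, it is equivalent to this direct construction (a sketch for clarity, not part of the commit):

# What _create_transformer_block(256) expands to with its defaults.
block = ConformerBlock(
    dim=256,
    dim_head=64,
    heads=4,
    ff_mult=1,
    conv_expansion_factor=2,
    conv_kernel_size=31,
    attn_dropout=0.05,   # one dropout value drives all three knobs
    ff_dropout=0.05,
    conv_dropout=0.05,
    conv_causal=False,
)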
@@ -152,9 +264,11 @@ class UNet(nn.Module):
         mask_states = [mask]
 
         for block in self.input_blocks:
-            res_net_block, downsample = block
+            res_net_block, transformer, downsample = block
 
             x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)
+
             hidden_states.append(x_t)
 
             if downsample is not None:
@@ -162,20 +276,23 @@ class UNet(nn.Module):
                 mask = mask[:, :, ::2]
                 mask_states.append(mask)
 
-        for _ in self.middle_blocks:
-            pass
+        for block in self.middle_blocks:
+            res_net_block, transformer = block
+            mask = mask_states[-1]
+            x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)
 
         for block in self.output_blocks:
-            res_net_block, upsample = block
-
+            res_net_block, transformer, upsample = block
 
             x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
             mask = mask_states.pop()
             x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)
 
             if upsample is not None:
                 x_t = upsample(x_t * mask)
 
 
         output = self.conv_block(x_t)
         output = self.conv(x_t)
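Also not part of the commit, but useful when reading the forward pass: each `Downsample1D` halves the time axis, so the mask is decimated with a stride-2 slice and pushed onto `mask_states`; on the way back up, the matching hidden state is re-joined along the channel axis with `einops.pack`. The two moves in isolation:

# Standalone sketch of the forward-pass bookkeeping (not part of the commit).
import torch
from einops import pack

x_t = torch.randn(2, 256, 100)
mask = torch.ones(2, 1, 100)

# Down path: decimate the mask to track the halved time axis.
mask_down = mask[:, :, ::2]              # (2, 1, 50)

# Up path: concatenate the skip connection along channels.
hidden = torch.randn(2, 256, 100)
x_cat, _ = pack([x_t, hidden], "b * t")  # (2, 512, 100)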