mirror of https://github.com/coqui-ai/TTS.git

Add transformer block to UNet

parent 5fd7ea93ea
commit f15230bb67
@@ -1,7 +1,8 @@
 import math
-from einops import pack
+from einops import pack, rearrange
 import torch
 from torch import nn
+import conformer


 class PositionalEncoding(torch.nn.Module):
@@ -71,6 +72,40 @@ class Upsample1D(nn.Module):
     def forward(self, x):
         return self.conv(x)


+class ConformerBlock(conformer.ConformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        dim_head: int = 64,
+        heads: int = 8,
+        ff_mult: int = 4,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        attn_dropout: float = 0.,
+        ff_dropout: float = 0.,
+        conv_dropout: float = 0.,
+        conv_causal: bool = False,
+    ):
+        super().__init__(
+            dim=dim,
+            dim_head=dim_head,
+            heads=heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            conv_dropout=conv_dropout,
+            conv_causal=conv_causal,
+        )
+
+    def forward(self, x, mask):
+        x = rearrange(x, "b c t -> b t c")
+        mask = rearrange(mask, "b 1 t -> b t")
+        output = super().forward(x=x, mask=mask.bool())
+        return rearrange(output, "b t c -> b c t")
+
+
 class UNet(nn.Module):
     def __init__(
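The ConformerBlock subclass added here adapts the `conformer` package's block, which expects time-major `(batch, time, channels)` input and a boolean `(batch, time)` padding mask, to the channels-first `(batch, channels, time)` layout used by the rest of the UNet. A minimal sketch of that shape contract; the 2/256/100 sizes are made up for illustration:

    import torch

    block = ConformerBlock(dim=256)   # dim must equal the channel width
    x = torch.randn(2, 256, 100)      # (batch, channels, time)
    mask = torch.ones(2, 1, 100)      # (batch, 1, time); 1 marks valid frames
    y = block(x, mask)                # rearranged to (b, t, c) internally, then back
    assert y.shape == x.shape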
@@ -80,6 +115,12 @@ class UNet(nn.Module):
         model_channels: int,
         out_channels: int,
         num_blocks: int,
+        transformer_num_heads: int = 4,
+        transformer_dim_head: int = 64,
+        transformer_ff_mult: int = 1,
+        transformer_conv_expansion_factor: int = 2,
+        transformer_conv_kernel_size: int = 31,
+        transformer_dropout: float = 0.05,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -107,6 +148,18 @@ class UNet(nn.Module):
                 )
             )

+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
             if level != num_blocks - 1:
                 block.append(Downsample1D(block_out_channels))
             else:
@@ -116,6 +169,30 @@ class UNet(nn.Module):
             self.input_blocks.append(block)

         self.middle_blocks = nn.ModuleList([])
+        for i in range(2):
+            block = nn.ModuleList([])
+
+            block.append(
+                ResNetBlock1D(
+                    in_channels=block_out_channels,
+                    out_channels=block_out_channels,
+                    time_embed_channels=time_embed_channels
+                )
+            )
+
+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
+            self.middle_blocks.append(block)

         self.output_blocks = nn.ModuleList([])
         block_in_channels = block_out_channels * 2
@@ -131,6 +208,18 @@ class UNet(nn.Module):
                 )
             )

+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
             if level != num_blocks - 1:
                 block.append(Upsample1D(block_out_channels))
             else:
@@ -142,6 +231,29 @@ class UNet(nn.Module):
         self.conv_block = ConvBlock1D(model_channels, model_channels)
         self.conv = nn.Conv1d(model_channels, self.out_channels, 1)

+    def _create_transformer_block(
+        self,
+        dim,
+        dim_head: int = 64,
+        num_heads: int = 4,
+        ff_mult: int = 1,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        dropout: float = 0.05,
+    ):
+        return ConformerBlock(
+            dim=dim,
+            dim_head=dim_head,
+            heads=num_heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=dropout,
+            ff_dropout=dropout,
+            conv_dropout=dropout,
+            conv_causal=False,
+        )
+
     def forward(self, x_t, mean, mask, t):
         t = self.time_encoder(t)
         t = self.time_embed(t)
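The `_create_transformer_block` helper fans the single `dropout` value out to the Conformer's three dropout sites (attention, feed-forward, convolution) and pins `conv_causal=False`, so the UNet constructor exposes one rate instead of three. A hypothetical expansion of one helper call, with 256 standing in for `block_out_channels`:

    transformer = ConformerBlock(
        dim=256,              # hypothetical block_out_channels
        dim_head=64,
        heads=4,
        ff_mult=1,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.05,    # one shared rate for all three sites
        ff_dropout=0.05,
        conv_dropout=0.05,
        conv_causal=False,
    )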
@@ -152,9 +264,11 @@ class UNet(nn.Module):
         mask_states = [mask]

         for block in self.input_blocks:
-            res_net_block, downsample = block
+            res_net_block, transformer, downsample = block

             x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)
+
             hidden_states.append(x_t)

             if downsample is not None:
@@ -162,20 +276,23 @@ class UNet(nn.Module):
                 mask = mask[:, :, ::2]
                 mask_states.append(mask)

-        for _ in self.middle_blocks:
-            pass
-
+        for block in self.middle_blocks:
+            res_net_block, transformer = block
+            mask = mask_states[-1]
+            x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)

         for block in self.output_blocks:
-            res_net_block, upsample = block
+            res_net_block, transformer, upsample = block

            x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
            mask = mask_states.pop()
            x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)

            if upsample is not None:
                x_t = upsample(x_t * mask)


        output = self.conv_block(x_t)
        output = self.conv(x_t)
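After this change each `input_blocks`/`output_blocks` entry unpacks to (ResNet block, transformer, resample) and each `middle_blocks` entry to (ResNet block, transformer), which is what the reworked loops rely on; the middle loop reuses the deepest mask via `mask_states[-1]` rather than popping it. A hedged end-to-end smoke test, assuming constructor arguments not visible in this diff and made-up shapes (the time axis should stay divisible by 2 per downsampling level so the strided masks line up):

    import torch

    unet = UNet(                      # all values hypothetical
        in_channels=80,
        model_channels=256,
        out_channels=80,
        num_blocks=2,
    )
    x_t = torch.randn(2, 80, 100)     # noisy sample (batch, channels, time)
    mean = torch.randn(2, 80, 100)    # conditioning mean, same layout
    mask = torch.ones(2, 1, 100)      # (batch, 1, time)
    t = torch.rand(2)                 # per-item diffusion timesteps
    output = unet(x_t, mean, mask, t)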