import math

import torch
from torch import nn
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
    """Conv1d layer with orthogonal weight and zero bias initialization."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.weight)
        nn.init.zeros_(self.bias)


class NoiseLevelEncoding(nn.Module):
    """Noise level encoding that applies the same encoding vector to all
    time steps. It differs from the original implementation."""

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        self.length = n_channels // 2
        assert n_channels % 2 == 0

        enc = self.init_encoding(self.length)
        self.register_buffer('enc', enc)

    def forward(self, x, noise_level):
        """
        Shapes:
            x: B x C x T
            noise_level: B
        """
        return x + self.encoding(noise_level)[:, :, None]

    @staticmethod
    def init_encoding(length):
        # Geometric frequency scale, as in sinusoidal positional encoding.
        div_by = torch.arange(length) / length
        enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
        return enc

    def encoding(self, noise_level):
        # [B, length] -> [B, 2 * length] by concatenating sin and cos parts.
        encoding = noise_level.unsqueeze(1) * self.enc
        encoding = torch.cat(
            [torch.sin(encoding), torch.cos(encoding)], dim=-1)
        return encoding
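

# A minimal usage sketch (illustrative only; the channel count, batch size and
# length below are assumptions, not values used by the model):
#
#     enc = NoiseLevelEncoding(128)
#     x = torch.randn(8, 128, 100)   # B x C x T
#     noise_level = torch.rand(8)    # B
#     out = enc(x, noise_level)      # B x C x T, same encoding at every step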


class FiLM(nn.Module):
    """Feature-wise Linear Modulation. It combines information from both the
    noisy waveform and the input mel-spectrogram. Given its inputs, the FiLM
    module produces scale and bias vectors, which a UBlock uses for a
    feature-wise affine transformation."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoding = NoiseLevelEncoding(in_channels)
        self.conv_in = Conv1d(in_channels, in_channels, 3, padding=1)
        self.conv_out = Conv1d(in_channels, out_channels * 2, 3, padding=1)
        self._init_parameters()

    def _init_parameters(self):
        nn.init.orthogonal_(self.conv_in.weight)
        nn.init.orthogonal_(self.conv_out.weight)

    def forward(self, x, noise_scale):
        x = self.conv_in(x)
        x = F.leaky_relu(x, 0.2)
        x = self.encoding(x, noise_scale)
        shift, scale = torch.chunk(self.conv_out(x), 2, dim=1)
        return shift, scale
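

# A minimal usage sketch (channel sizes are assumptions; in_channels must be
# even for the internal NoiseLevelEncoding):
#
#     film = FiLM(80, 128)
#     spec = torch.randn(8, 80, 100)           # B x in_channels x T
#     noise_scale = torch.rand(8)              # B
#     shift, scale = film(spec, noise_scale)   # each B x out_channels x T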


@torch.jit.script
def shift_and_scale(x, scale, shift):
    """Apply the feature-wise affine transformation ``shift + scale * x``."""
    o = shift + scale * x
    return o


class UBlock(nn.Module):
    """Upsampling block with two FiLM-conditioned residual stages."""

    def __init__(self, in_channels, hid_channels, upsample_factor, dilations):
        super().__init__()
        assert len(dilations) == 4

        self.upsample_factor = upsample_factor
        self.shortcut_conv = Conv1d(in_channels, hid_channels, 1)
        self.main_block1 = nn.ModuleList([
            Conv1d(in_channels,
                   hid_channels,
                   3,
                   dilation=dilations[0],
                   padding=dilations[0]),
            Conv1d(hid_channels,
                   hid_channels,
                   3,
                   dilation=dilations[1],
                   padding=dilations[1])
        ])
        self.main_block2 = nn.ModuleList([
            Conv1d(hid_channels,
                   hid_channels,
                   3,
                   dilation=dilations[2],
                   padding=dilations[2]),
            Conv1d(hid_channels,
                   hid_channels,
                   3,
                   dilation=dilations[3],
                   padding=dilations[3])
        ])

    def forward(self, x, shift, scale):
        # Upsample both the main path and the shortcut to the target length.
        upsample_size = x.shape[-1] * self.upsample_factor
        x = F.interpolate(x, size=upsample_size)
        res = self.shortcut_conv(x)

        # First residual stage with FiLM conditioning.
        o = F.leaky_relu(x, 0.2)
        o = self.main_block1[0](o)
        o = shift_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block1[1](o)

        o = o + res
        res = o

        # Second residual stage with FiLM conditioning.
        o = shift_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block2[0](o)
        o = shift_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block2[1](o)

        o = o + res
        return o
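

# A minimal usage sketch (factor, dilations and shapes are assumptions): the
# shift/scale pair comes from a matching FiLM module and must already have the
# upsampled length:
#
#     ublock = UBlock(128, 64, upsample_factor=2, dilations=[1, 2, 4, 8])
#     x = torch.randn(8, 128, 50)       # B x in_channels x T
#     shift = torch.randn(8, 64, 100)   # B x hid_channels x (T * factor)
#     scale = torch.randn(8, 64, 100)
#     o = ublock(x, shift, scale)       # B x hid_channels x (T * factor)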


class DBlock(nn.Module):
    """Downsampling block with a dilated convolution stack and an
    interpolated shortcut."""

    def __init__(self, in_channels, hid_channels, downsample_factor):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.res_conv = Conv1d(in_channels, hid_channels, 1)
        self.main_convs = nn.ModuleList([
            Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
            Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
            Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
        ])

    def forward(self, x):
        size = x.shape[-1] // self.downsample_factor

        # Downsample the shortcut path with a 1x1 conv and interpolation.
        res = self.res_conv(x)
        res = F.interpolate(res, size=size)

        # Downsample the main path, then apply the dilated conv stack.
        o = F.interpolate(x, size=size)
        for layer in self.main_convs:
            o = F.leaky_relu(o, 0.2)
            o = layer(o)

        return o + res
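

# A minimal usage sketch (shapes are assumptions): downsample a feature map by
# a factor of 2 along time:
#
#     dblock = DBlock(64, 128, downsample_factor=2)
#     x = torch.randn(8, 64, 100)    # B x in_channels x T
#     o = dblock(x)                  # B x hid_channels x (T // factor)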