# coqui-tts/TTS/vocoder/layers/wavegrad.py
import math
import torch
from torch import nn
from torch.nn import functional as F
class Conv1d(nn.Conv1d):
    """1D convolution with orthogonal weight init and zero-initialized bias.

    Thin wrapper over ``nn.Conv1d`` that replaces the default (Kaiming)
    initialization right after construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonal weights; zero bias when a bias tensor exists."""
        nn.init.orthogonal_(self.weight)
        # Guard: nn.Conv1d sets self.bias to None when bias=False is passed;
        # calling zeros_ on None would raise TypeError.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
class NoiseLevelEncoding(nn.Module):
    """Sinusoidal noise-level embedding broadcast over all time steps.

    The same encoding vector is added to every time step, which differs
    from the original WaveGrad implementation.
    """

    def __init__(self, n_channels):
        super().__init__()
        assert n_channels % 2 == 0
        self.n_channels = n_channels
        self.length = n_channels // 2
        self.register_buffer('enc', self.init_encoding(self.length))

    def forward(self, x, noise_level):
        """Add the noise-level embedding to ``x``.

        Shapes:
            x: B x C x T
            noise_level: B
        """
        emb = self.encoding(noise_level)
        return x + emb[:, :, None]

    @staticmethod
    def init_encoding(length):
        """Build the (1, length) table of exponentially decaying frequencies."""
        exponents = torch.arange(length) / length
        return torch.exp(-math.log(1e4) * exponents.unsqueeze(0))

    def encoding(self, noise_level):
        """Map each noise level to a (B, n_channels) sin/cos feature vector."""
        scaled = noise_level.unsqueeze(1) * self.enc
        return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
class FiLM(nn.Module):
    """Feature-wise Linear Modulation.

    Fuses information from the noisy waveform and the input
    mel-spectrogram: given a feature map and a noise scale, it produces
    per-channel (shift, scale) pairs that a UBlock consumes as a
    feature-wise affine transformation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoding = NoiseLevelEncoding(in_channels)
        self.conv_in = Conv1d(in_channels, in_channels, 3, padding=1)
        # Twice out_channels: the output is split into (shift, scale).
        self.conv_out = Conv1d(in_channels, out_channels * 2, 3, padding=1)
        self._init_parameters()

    def _init_parameters(self):
        # Re-apply orthogonal init (Conv1d already does this on construction).
        for conv in (self.conv_in, self.conv_out):
            nn.init.orthogonal_(conv.weight)

    def forward(self, x, noise_scale):
        """Return (shift, scale), each shaped B x out_channels x T."""
        hidden = F.leaky_relu(self.conv_in(x), 0.2)
        hidden = self.encoding(hidden, noise_scale)
        shift, scale = torch.chunk(self.conv_out(hidden), 2, dim=1)
        return shift, scale
@torch.jit.script
def shif_and_scale(x, scale, shift):
    """Feature-wise affine transform: ``shift + scale * x``.

    NOTE(review): the name keeps the original "shif" typo because other
    blocks in this file reference it; renaming would break callers.
    """
    return scale * x + shift
class UBlock(nn.Module):
    """Upsampling residual block modulated by FiLM (shift, scale) pairs.

    Upsamples the input by ``upsample_factor`` with nearest interpolation,
    then applies two dilated-conv sub-blocks with feature-wise affine
    modulation and residual connections.
    """

    def __init__(self, in_channels, hid_channels, upsample_factor, dilations):
        super().__init__()
        assert len(dilations) == 4
        self.upsample_factor = upsample_factor
        self.shortcut_conv = Conv1d(in_channels, hid_channels, 1)
        d1, d2, d3, d4 = dilations
        # Padding equals dilation so the 3-tap convs preserve length.
        self.main_block1 = nn.ModuleList([
            Conv1d(in_channels, hid_channels, 3, dilation=d1, padding=d1),
            Conv1d(hid_channels, hid_channels, 3, dilation=d2, padding=d2),
        ])
        self.main_block2 = nn.ModuleList([
            Conv1d(hid_channels, hid_channels, 3, dilation=d3, padding=d3),
            Conv1d(hid_channels, hid_channels, 3, dilation=d4, padding=d4),
        ])

    def forward(self, x, shift, scale):
        """Upsample ``x`` and run both modulated residual sub-blocks."""
        x = F.interpolate(x, size=x.shape[-1] * self.upsample_factor)
        res = self.shortcut_conv(x)
        o = self.main_block1[0](F.leaky_relu(x, 0.2))
        o = self.main_block1[1](
            F.leaky_relu(shif_and_scale(o, scale, shift), 0.2))
        o = o + res
        res = o
        o = self.main_block2[0](
            F.leaky_relu(shif_and_scale(o, scale, shift), 0.2))
        o = self.main_block2[1](
            F.leaky_relu(shif_and_scale(o, scale, shift), 0.2))
        return o + res
class DBlock(nn.Module):
    """Downsampling residual block.

    Shrinks the time axis by ``downsample_factor`` via nearest
    interpolation on both the residual and main paths, then applies a
    stack of increasingly dilated convolutions on the main path.
    """

    def __init__(self, in_channels, hid_channels, downsample_factor):
        super().__init__()
        self.downsample_factor = downsample_factor
        self.res_conv = Conv1d(in_channels, hid_channels, 1)
        # Dilations 1/2/4 with matching padding keep the length fixed.
        self.main_convs = nn.ModuleList([
            Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
            Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
            Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
        ])

    def forward(self, x):
        """Downsample ``x`` (B x C x T) and return main + residual paths."""
        target = x.shape[-1] // self.downsample_factor
        res = F.interpolate(self.res_conv(x), size=target)
        o = F.interpolate(x, size=target)
        for conv in self.main_convs:
            o = conv(F.leaky_relu(o, 0.2))
        return o + res