# mirror of https://github.com/coqui-ai/TTS.git
import numpy as np
import torch
from torch import nn

from ..layers.wavegrad import DBlock, FiLM, UBlock

class Wavegrad(nn.Module):
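    """WaveGrad vocoder model (https://arxiv.org/abs/2009.00713).

    A conditional diffusion network: given a noisy waveform ``x``, a mel
    spectrogram ``c``, and a noise scale, it predicts the noise component of
    ``x``. Applying it iteratively in ``inference()`` turns random noise into
    a waveform aligned with ``c``.
    """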
    # pylint: disable=dangerous-default-value
    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 x_conv_channels=32,
                 c_conv_channels=768,
                 dblock_out_channels=[128, 128, 256, 512],
                 ublock_out_channels=[512, 512, 256, 128, 128],
                 upsample_factors=[5, 5, 3, 2, 2],
                 upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
                                     [1, 2, 4, 8], [1, 2, 4, 8]]):
        super().__init__()

        assert len(upsample_factors) == len(upsample_dilations)
        assert len(upsample_factors) == len(ublock_out_channels)

        # set up up-/down-sampling parameters; hop_length is the total
        # upsampling ratio between conditioning frames and waveform samples
        self.hop_length = np.prod(upsample_factors)
        self.upsample_factors = upsample_factors
        self.downsample_factors = upsample_factors[::-1][:-1]

        ### define DBlocks and FiLM layers ###
        self.dblocks = nn.ModuleList([
            nn.Conv1d(out_channels, x_conv_channels, 5, padding=2),
        ])
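        # each DBlock below downsamples by one factor, and its paired FiLM
        # layer maps the features plus the noise scale to a (shift, scale)
        # pair that forward() feeds, in reverse order, to the UBlock running
        # at the same time resolution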
        ic = x_conv_channels
        self.films = nn.ModuleList([])
        for oc, df in zip(dblock_out_channels, self.downsample_factors):
            layer = DBlock(ic, oc, df)
            self.dblocks.append(layer)

            layer = FiLM(ic, oc)
            self.films.append(layer)
            ic = oc
        # last FiLM block
        self.films.append(FiLM(ic, dblock_out_channels[-1]))

        ### define UBlocks ###
        self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1)
        self.ublocks = nn.ModuleList([])
        ic = c_conv_channels
        for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)):
            layer = UBlock(ic, oc, uf, upsample_dilations[idx])
            self.ublocks.append(layer)
            ic = oc

        # define the last layer
        self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)

        # inference-time noise schedule parameters
        self.S = 1000
        self.init_noise_schedule(self.S)

    def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01):
        """Compute the noise schedule parameters and register them as buffers."""
        device = self.last_conv.weight.device
        beta = torch.linspace(min_val, max_val, num_iter).to(device)
        alpha = 1 - beta
        alpha_cum = alpha.cumprod(dim=0)
        noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum**0.5])

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_cum', alpha_cum)
        self.register_buffer('noise_level', noise_level)
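
    # With the defaults, ``beta`` ramps linearly from 1e-6 to 1e-2 over
    # S = 1000 steps; ``noise_level`` prepends 1 to sqrt(cumprod(1 - beta)),
    # so ``noise_level[s]`` is the signal scale remaining after ``s`` steps
    # and compute_noisy_x can sample a continuous scale between two steps.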
    def compute_noisy_x(self, x):
        """Corrupt a clean waveform ``x`` at a randomly sampled noise level.

        Returns the injected noise (the training target), the noisy waveform,
        and the per-sample noise scale, each with a channel dimension restored."""
        B = x.shape[0]
        if len(x.shape) == 3:
            x = x.squeeze(1)
        s = torch.randint(1, self.S + 1, [B]).to(x).long()
        l_a, l_b = self.noise_level[s - 1], self.noise_level[s]
        noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a)
        noise_scale = noise_scale.unsqueeze(1)
        noise = torch.randn_like(x)
        noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
        return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0]

    def forward(self, x, c, noise_scale):
        """Predict the noise component of the noisy waveform ``x``,
        conditioned on the spectrogram ``c`` and the noise scale."""
        assert len(c.shape) == 3  # B, C, T
        assert len(x.shape) == 3  # B, 1, T
        o = x
        shift_and_scales = []
        for film, dblock in zip(self.films, self.dblocks):
            o = dblock(o)
            shift_and_scales.append(film(o, noise_scale))

        o = self.c_conv(c)
        for ublock, (film_shift, film_scale) in zip(self.ublocks,
                                                    reversed(shift_and_scales)):
            o = ublock(o, film_shift, film_scale)
        o = self.last_conv(o)
        return o
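
    # Each sampling step applies the DDPM-style update: with
    #   c1 = 1 / sqrt(alpha[n])  and  c2 = (1 - alpha[n]) / sqrt(1 - alpha_cum[n]),
    #   x <- c1 * (x - c2 * model(x, c, noise_scale[n])),
    # and for n > 0 Gaussian noise is re-injected with
    #   sigma = sqrt((1 - alpha_cum[n - 1]) / (1 - alpha_cum[n]) * beta[n]).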
    def inference(self, c):
        """Generate a waveform from the spectrogram ``c``, starting from noise."""
        with torch.no_grad():
            x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c)
            noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c)
            for n in range(len(self.alpha) - 1, -1, -1):
                c1 = 1 / self.alpha[n]**0.5
                c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
                # forward() returns (B, 1, T); keeping the channel dim here
                # (instead of squeezing it away) avoids a broadcasting shape
                # mismatch when B > 1
                x = c1 * (x - c2 * self.forward(x, c, noise_scale[n]))
                if n > 0:
                    noise = torch.randn_like(x)
                    sigma = ((1.0 - self.alpha_cum[n - 1]) /
                             (1.0 - self.alpha_cum[n]) * self.beta[n])**0.5
                    x += sigma * noise
                x = torch.clamp(x, -1.0, 1.0)
        return x
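

# A minimal usage sketch, not part of the original module: shapes assume the
# defaults above (80-band mel conditioning, hop_length = 5*5*3*2*2 = 300).
# Because of the relative import at the top, run it inside the package
# (e.g. ``python -m TTS.vocoder.models.wavegrad``; module path assumed here),
# not as a loose script.
if __name__ == "__main__":
    model = Wavegrad()
    c = torch.randn(2, 80, 16)                     # (B, n_mels, frames)
    x = torch.randn(2, 1, 16 * model.hop_length)   # (B, 1, frames * hop_length)

    # training-style step: corrupt x, then predict the injected noise
    noise, noisy_x, noise_scale = model.compute_noisy_x(x)
    noise_hat = model(noisy_x, c, noise_scale)
    print(noise_hat.shape)                         # torch.Size([2, 1, 4800])

    # inference: start from pure noise and denoise for model.S steps
    audio = model.inference(c)
    print(audio.shape)                             # torch.Size([2, 1, 4800])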